{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Databricks Notebook with MLflow AutoLogging for LiteLLM Proxy calls\n",
    "\n",
    "Builds a minimal LangChain chain whose chat model is reached through a LiteLLM proxy (via the OpenAI-compatible `ChatOpenAI` client) and enables `mlflow.langchain.autolog()` so that invocations of the chain are traced by MLflow.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "5e2812ed-8000-4793-b090-49a31464d810",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "%pip install -U -qqqq databricks-agents mlflow langchain==0.3.1 langchain-core==0.3.6 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "52530b37-1860-4bba-a6c1-723de83bc58f",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "%pip install \"langchain-openai<=0.3.1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "43c6f4b1-e2d5-431c-b1a2-b97df7707d59",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Before logging this chain using the driver notebook, you must comment out this line.\n",
    "dbutils.library.restartPython() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "88eb8dd7-16b1-480b-aa70-cd429ef87159",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "from operator import itemgetter\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "from langchain_core.runnables import RunnableLambda\n",
    "from langchain_databricks import ChatDatabricks\n",
    "from langchain_openai import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "f0fdca8f-6f6f-407c-ad4a-0d5a2778728e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Turn on MLflow autologging for LangChain; `mlflow` is already imported in the imports cell.\n",
    "mlflow.langchain.autolog()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "2ef67315-e468-4d60-a318-98c2cac75bc4",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# These helper functions parse the `messages` array.\n",
    "\n",
    "\n",
    "def extract_user_query_string(chat_messages_array):\n",
    "    \"\"\"Return the string contents of the most recent message from the user.\n",
    "\n",
    "    The last element of `chat_messages_array` is looked up by its\n",
    "    `\"content\"` key, so an empty array raises `IndexError`.\n",
    "    \"\"\"\n",
    "    return chat_messages_array[-1][\"content\"]\n",
    "\n",
    "\n",
    "def extract_chat_history(chat_messages_array):\n",
    "    \"\"\"Return the chat history, which is everything before the last question.\"\"\"\n",
    "    return chat_messages_array[:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "17708467-1976-48bd-94a0-8c7895cfae3b",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Chat model reached through the LiteLLM proxy's OpenAI-compatible API.\n",
    "# The two upper-case strings are placeholders: replace them with your proxy URL and key.\n",
    "model = ChatOpenAI(\n",
    "    openai_api_base=\"LITELLM_PROXY_BASE_URL\",  # e.g.: http://0.0.0.0:4000\n",
    "    model=\"gpt-3.5-turbo\",  # LITELLM 'model_name'\n",
    "    temperature=0.1,\n",
    "    api_key=\"LITELLM_PROXY_API_KEY\",  # e.g.: \"sk-1234\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "a5f2c2af-82f7-470d-b559-47b67fb00cda",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "############\n",
    "# Prompt Template for generation\n",
    "############\n",
    "prompt = PromptTemplate(\n",
    "    template=\"You are a hello world bot. Respond with a reply to the user's question that is fun and interesting to the user. User's question: {question}\",\n",
    "    input_variables=[\"question\"],\n",
    ")\n",
    "\n",
    "############\n",
    "# Simple chain\n",
    "############\n",
    "# `model` is the LiteLLM-proxy-backed ChatOpenAI instance created in the previous cell.\n",
    "# Do not rebind `model` in this cell: doing so would make the chain bypass the proxy.\n",
    "# The framework requires the chain to return a string value.\n",
    "chain = (\n",
    "    {\n",
    "        \"question\": itemgetter(\"messages\")\n",
    "        | RunnableLambda(extract_user_query_string),\n",
    "        \"chat_history\": itemgetter(\"messages\") | RunnableLambda(extract_chat_history),\n",
    "    }\n",
    "    | prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "366edd90-62a1-4d6f-8a65-0211fb24ca02",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Hello there! I\\'m here to help with your questions. Regarding your query about \"rag,\" it\\'s not something typically associated with a \"hello world\" bot, but I\\'m happy to explain!\\n\\nRAG, or Remote Angular GUI, is a tool that allows you to create and manage Angular applications remotely. It\\'s a way to develop and test Angular components and applications without needing to set up a local development environment. This can be particularly useful for teams working on distributed systems or for developers who prefer to work in a cloud-based environment.\\n\\nI hope this explanation of RAG has been helpful and interesting! If you have any other questions or need further clarification, feel free to ask.'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "application/databricks.mlflow.trace": "\"tr-ea2226413395413ba2cf52cffc523502\"",
      "text/plain": [
       "Trace(request_id=tr-ea2226413395413ba2cf52cffc523502)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# This is the same input your chain's REST API will accept.\n",
    "question = {\n",
    "    \"messages\": [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": \"what is rag?\",\n",
    "        },\n",
    "    ]\n",
    "}\n",
    "\n",
    "chain.invoke(question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "5d68e37d-0980-4a02-bf8d-885c3853f6c1",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Register the full chain (prompt + proxy-backed model + parser), not the bare chat model,\n",
    "# as the object MLflow logs when this notebook is logged as code.\n",
    "mlflow.models.set_model(model=chain)"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "environmentMetadata": null,
   "language": "python",
   "notebookMetadata": {
    "pythonIndentUnit": 4
   },
   "notebookName": "Untitled Notebook 2024-10-16 19:35:16",
   "widgets": {}
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}