From c8a2782df8044b0e528e35a9576e3303b17aa30c Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 12 Jul 2024 18:48:40 -0700
Subject: [PATCH] docs(pass_through.md): add doc on creating custom chat
 endpoints on proxy

Allows developers to call proxy with anthropic sdk/boto3/etc.
---
 docs/my-website/docs/proxy/pass_through.md   | 148 ++++++++++++-
 litellm/proxy/_new_secret_config.yaml        |   8 +-
 litellm/proxy/auth/user_api_key_auth.py      |  16 +-
 .../pass_through_endpoints.py                | 208 +++++++++++++++++-
 litellm/tests/test_pass_through_endpoints.py |  49 +++++
 5 files changed, 419 insertions(+), 10 deletions(-)

diff --git a/docs/my-website/docs/proxy/pass_through.md b/docs/my-website/docs/proxy/pass_through.md
index 1348a2fc1c..82a374503b 100644
--- a/docs/my-website/docs/proxy/pass_through.md
+++ b/docs/my-website/docs/proxy/pass_through.md
@@ -217,4 +217,150 @@ general_settings:
 * `accept` *string*: The expected response format from the server.
 * `LANGFUSE_PUBLIC_KEY` *string*: Your Langfuse account public key - only set this when forwarding to Langfuse.
 * `LANGFUSE_SECRET_KEY` *string*: Your Langfuse account secret key - only set this when forwarding to Langfuse.
- * `<your-custom-header>` *string*: Pass any custom header key/value pair
\ No newline at end of file
+ * `<your-custom-header>` *string*: Pass any custom header key/value pair


## Custom Chat Endpoints (Anthropic/Bedrock/Vertex)

Allow developers to call the proxy with the Anthropic SDK, boto3, etc.

Use our [Anthropic Adapter](https://github.com/BerriAI/litellm/blob/fd743aaefd23ae509d8ca64b0c232d25fe3e39ee/litellm/adapters/anthropic_adapter.py#L50) for reference.

### 1. Write an Adapter

Translate the request/response from your custom API schema to the OpenAI schema and back.

This is used internally to run Logging, Guardrails, etc. in a consistent format.

For provider-specific params 👉 [**Provider-Specific Params**](../completion/provider_specific_params.md)

```python
from litellm import adapter_completion
import litellm
from litellm import ChatCompletionRequest, verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.llms.anthropic import AnthropicMessagesRequest, AnthropicResponse
import os

# What is this?
## Translates OpenAI call to Anthropic `/v1/messages` format
import json
import os
import traceback
import uuid
from typing import Literal, Optional

import dotenv
import httpx
from pydantic import BaseModel


###################
# CUSTOM ADAPTER ##
###################

class AnthropicAdapter(CustomLogger):
    def __init__(self) -> None:
        super().__init__()

    def translate_completion_input_params(
        self, kwargs
    ) -> Optional[ChatCompletionRequest]:
        """
        - translate params, where needed
        - pass the rest, as is
        """
        request_body = AnthropicMessagesRequest(**kwargs)  # type: ignore

        translated_body = litellm.AnthropicConfig().translate_anthropic_to_openai(
            anthropic_message_request=request_body
        )

        return translated_body

    def translate_completion_output_params(
        self, response: litellm.ModelResponse
    ) -> Optional[AnthropicResponse]:

        return litellm.AnthropicConfig().translate_openai_response_to_anthropic(
            response=response
        )

    def translate_completion_output_params_streaming(self) -> Optional[BaseModel]:
        return super().translate_completion_output_params_streaming()


anthropic_adapter = AnthropicAdapter()

###########
# TEST IT #
###########

## register CUSTOM ADAPTER
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]

## set ENV variables
os.environ["OPENAI_API_KEY"] = "your-openai-key"
os.environ["COHERE_API_KEY"] = "your-cohere-key"

messages = [{"role": "user", "content": "Hello, how are you?"}]

# openai call
response = adapter_completion(model="gpt-3.5-turbo", messages=messages, adapter_id="anthropic")

# cohere call
response = adapter_completion(model="command-nightly", messages=messages, adapter_id="anthropic")
print(response)
```

### 2. Create new endpoint

We pass the adapter instance defined in Step 1 to the config.yaml. Set `target` to `python_filename.logger_instance_name`.

In the config below, we set:

- python_filename: `custom_callbacks.py`
- logger_instance_name: `anthropic_adapter` (the adapter instance defined in Step 1)

This gives `target: custom_callbacks.anthropic_adapter`

```yaml
model_list:
  - model_name: my-fake-claude-endpoint
    litellm_params:
      model: gpt-3.5-turbo
      api_key: os.environ/OPENAI_API_KEY


general_settings:
  master_key: sk-1234
  pass_through_endpoints:
    - path: "/v1/messages"                         # route you want to add to LiteLLM Proxy Server
      target: custom_callbacks.anthropic_adapter   # adapter to use for this route
      headers:
        litellm_user_api_key: "x-api-key"          # field in headers containing the LiteLLM key
```

### 3. Test it!
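
Because the route follows the Anthropic `/v1/messages` schema, you can also point the Anthropic Python SDK at the proxy instead of using curl. A minimal sketch, assuming the `anthropic` package is installed and the proxy from the config above is running on `http://0.0.0.0:4000` (start it as shown below):

```python
# Hedged sketch: call the LiteLLM pass-through route with the Anthropic Python SDK.
# Assumes `pip install anthropic` and the example config above
# (master_key sk-1234, /v1/messages mapped to the custom adapter).
import anthropic

client = anthropic.Anthropic(
    base_url="http://0.0.0.0:4000",  # point the SDK at the LiteLLM proxy, not api.anthropic.com
    api_key="sk-1234",               # your LiteLLM key; the SDK sends it in the x-api-key header
)

message = client.messages.create(
    model="my-fake-claude-endpoint",  # model_name from the config above
    max_tokens=1024,
    messages=[{"role": "user", "content": "Hello, world"}],
)
print(message.content)
```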

**Start proxy**

```bash
litellm --config /path/to/config.yaml
```

**Curl**

The `anthropic-version` header is accepted but ignored by the proxy.

```bash
curl --location 'http://0.0.0.0:4000/v1/messages' \
-H 'x-api-key: sk-1234' \
-H 'anthropic-version: 2023-06-01' \
-H 'content-type: application/json' \
-d '{
  "model": "my-fake-claude-endpoint",
  "max_tokens": 1024,
  "messages": [
    {"role": "user", "content": "Hello, world"}
  ]
}'
```
\ No newline at end of file
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index a8c9e88233..cf4a823c39 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -13,4 +13,10 @@ model_list:
 
 general_settings:
   alerting: ["slack"]
-  alerting_threshold: 10
\ No newline at end of file
+  alerting_threshold: 10
+  master_key: sk-1234
+  pass_through_endpoints:
+    - path: "/v1/test-messages"                                    # route you want to add to LiteLLM Proxy Server
+      target: litellm.adapters.anthropic_adapter.anthropic_adapter # URL this route should forward requests to
+      headers:                                                     # headers to forward to this URL
+        litellm_user_api_key: "x-my-test-key"
\ No newline at end of file
diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py
index 4b931a2726..5b3d86d557 100644
--- a/litellm/proxy/auth/user_api_key_auth.py
+++ b/litellm/proxy/auth/user_api_key_auth.py
@@ -115,6 +115,12 @@ async def user_api_key_auth(
     )
 
     try:
+        route: str = request.url.path
+
+        pass_through_endpoints: Optional[List[dict]] = general_settings.get(
+            "pass_through_endpoints", None
+        )
+
         if isinstance(api_key, str):
             passed_in_key = api_key
             api_key = _get_bearer_token(api_key=api_key)
@@ -125,6 +131,14 @@ async def user_api_key_auth(
 
         elif isinstance(anthropic_api_key_header, str):
             api_key = anthropic_api_key_header
+        elif pass_through_endpoints is not None:
+            for endpoint in pass_through_endpoints:
+                if endpoint.get("path", "") == route:
+                    headers: Optional[dict] = endpoint.get("headers", None)
+                    if headers is not None:
+                        header_key: str = headers.get("litellm_user_api_key", "")
+                        if request.headers.get(key=header_key) is not None:
+                            api_key = request.headers.get(key=header_key)
 
         parent_otel_span: Optional[Span] = None
         if open_telemetry_logger is not None:
             parent_otel_span = open_telemetry_logger.tracer.start_span(
@@ -163,8 +177,6 @@ async def user_api_key_auth(
                 detail="Access forbidden: IP address not allowed.",
             )
 
-        route: str = request.url.path
-
         if (
             route in LiteLLMRoutes.public_routes.value
             or route_in_additonal_public_routes(current_route=route)
diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
index 218032e012..b13e9834a2 100644
--- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
+++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
@@ -1,4 +1,6 @@
 import ast
+import asyncio
+import json
 import traceback
 from base64 import b64encode
@@ -16,7 +18,8 @@ from fastapi.responses import StreamingResponse
 
 import litellm
 from litellm._logging import verbose_proxy_logger
-from litellm.proxy._types import ProxyException
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import ProxyException, UserAPIKeyAuth
 from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
 
 async_client = httpx.AsyncClient()
@@ -24,7 +27,7 @@ async_client = httpx.AsyncClient()
 
 async def set_env_variables_in_header(custom_headers: dict):
     """
-    checks if nay headers on config.yaml are defined as
os.environ/COHERE_API_KEY etc + checks if any headers on config.yaml are defined as os.environ/COHERE_API_KEY etc only runs for headers defined on config.yaml @@ -72,6 +75,171 @@ async def set_env_variables_in_header(custom_headers: dict): return headers +async def chat_completion_pass_through_endpoint( + fastapi_response: Response, + request: Request, + adapter_id: str, + user_api_key_dict: UserAPIKeyAuth, +): + from litellm.proxy.proxy_server import ( + add_litellm_data_to_request, + general_settings, + get_custom_headers, + llm_router, + proxy_config, + proxy_logging_obj, + user_api_base, + user_max_tokens, + user_model, + user_request_timeout, + user_temperature, + version, + ) + + data = {} + try: + body = await request.body() + body_str = body.decode() + try: + data = ast.literal_eval(body_str) + except Exception: + data = json.loads(body_str) + + data["adapter_id"] = adapter_id + + verbose_proxy_logger.debug( + "Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)), + ) + data["model"] = ( + general_settings.get("completion_model", None) # server default + or user_model # model name passed via cli args + or data["model"] # default passed in http request + ) + if user_model: + data["model"] = user_model + + data = await add_litellm_data_to_request( + data=data, # type: ignore + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) + + # override with user settings, these are params passed via cli + if user_temperature: + data["temperature"] = user_temperature + if user_request_timeout: + data["request_timeout"] = user_request_timeout + if user_max_tokens: + data["max_tokens"] = user_max_tokens + if user_api_base: + data["api_base"] = user_api_base + + ### MODEL ALIAS MAPPING ### + # check if model name in model alias map + # get the actual model name + if data["model"] in litellm.model_alias_map: + data["model"] = litellm.model_alias_map[data["model"]] + + ### CALL HOOKS ### - modify incoming data before calling the model + data = await proxy_logging_obj.pre_call_hook( # type: ignore + user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion" + ) + + ### ROUTE THE REQUESTs ### + router_model_names = llm_router.model_names if llm_router is not None else [] + # skip router if user passed their key + if "api_key" in data: + llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) + elif ( + llm_router is not None and data["model"] in router_model_names + ): # model in router model list + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif ( + llm_router is not None + and llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): # model set in model_group_alias + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif ( + llm_router is not None and data["model"] in llm_router.deployment_names + ): # model in router deployments, calling a specific deployment on the router + llm_response = asyncio.create_task( + llm_router.aadapter_completion(**data, specific_deployment=True) + ) + elif ( + llm_router is not None and data["model"] in llm_router.get_model_ids() + ): # model in router model list + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif ( + llm_router is not None + and data["model"] not in router_model_names + and llm_router.default_deployment is not None + ): # model in router deployments, calling a specific deployment 
on the router + llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) + elif user_model is not None: # `litellm --model ` + llm_response = asyncio.create_task(litellm.aadapter_completion(**data)) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "completion: Invalid model name passed in model=" + + data.get("model", "") + }, + ) + + # Await the llm_response task + response = await llm_response + + hidden_params = getattr(response, "_hidden_params", {}) or {} + model_id = hidden_params.get("model_id", None) or "" + cache_key = hidden_params.get("cache_key", None) or "" + api_base = hidden_params.get("api_base", None) or "" + response_cost = hidden_params.get("response_cost", None) or "" + + ### ALERTING ### + asyncio.create_task( + proxy_logging_obj.update_request_status( + litellm_call_id=data.get("litellm_call_id", ""), status="success" + ) + ) + + verbose_proxy_logger.debug("final response: %s", response) + + fastapi_response.headers.update( + get_custom_headers( + user_api_key_dict=user_api_key_dict, + model_id=model_id, + cache_key=cache_key, + api_base=api_base, + version=version, + response_cost=response_cost, + ) + ) + + verbose_proxy_logger.info("\nResponse from Litellm:\n{}".format(response)) + return response + except Exception as e: + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data + ) + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.completion(): Exception occured - {}\n{}".format( + str(e), traceback.format_exc() + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + error_msg = f"{str(e)}" + raise ProxyException( + message=getattr(e, "message", error_msg), + type=getattr(e, "type", "None"), + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", 500), + ) + + async def pass_through_request(request: Request, target: str, custom_headers: dict): try: @@ -106,7 +274,7 @@ async def pass_through_request(request: Request, target: str, custom_headers: di ) except Exception as e: verbose_proxy_logger.error( - "litellm.proxy.proxy_server.pass through endpoint(): Exception occured - {}".format( + "litellm.proxy.proxy_server.pass_through_endpoint(): Exception occured - {}".format( str(e) ) ) @@ -128,9 +296,37 @@ async def pass_through_request(request: Request, target: str, custom_headers: di ) -def create_pass_through_route(endpoint, target, custom_headers=None): - async def endpoint_func(request: Request): - return await pass_through_request(request, target, custom_headers) +def create_pass_through_route(endpoint, target: str, custom_headers=None): + # check if target is an adapter.py or a url + import uuid + + from litellm.proxy.utils import get_instance_fn + + try: + if isinstance(target, CustomLogger): + adapter = target + else: + adapter = get_instance_fn(value=target) + adapter_id = str(uuid.uuid4()) + litellm.adapters = [{"id": adapter_id, "adapter": adapter}] + + async def endpoint_func( + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + ): + return await chat_completion_pass_through_endpoint( + fastapi_response=fastapi_response, + request=request, + adapter_id=adapter_id, + user_api_key_dict=user_api_key_dict, + ) + + except Exception: + verbose_proxy_logger.warning("Defaulting to target being a url.") + + async def endpoint_func(request: Request): # type: ignore + return await pass_through_request(request, target, custom_headers) return 
endpoint_func diff --git a/litellm/tests/test_pass_through_endpoints.py b/litellm/tests/test_pass_through_endpoints.py index 0f234dfa8b..43543ecc76 100644 --- a/litellm/tests/test_pass_through_endpoints.py +++ b/litellm/tests/test_pass_through_endpoints.py @@ -83,3 +83,52 @@ async def test_pass_through_endpoint_rerank(client): # Assert the response assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_pass_through_endpoint_anthropic(client): + import litellm + from litellm import Router + from litellm.adapters.anthropic_adapter import anthropic_adapter + + router = Router( + model_list=[ + { + "model_name": "gpt-3.5-turbo", + "litellm_params": { + "model": "gpt-3.5-turbo", + "api_key": os.getenv("OPENAI_API_KEY"), + "mock_response": "Hey, how's it going?", + }, + } + ] + ) + + setattr(litellm.proxy.proxy_server, "llm_router", router) + + # Define a pass-through endpoint + pass_through_endpoints = [ + { + "path": "/v1/test-messages", + "target": anthropic_adapter, + "headers": {"litellm_user_api_key": "my-test-header"}, + } + ] + + # Initialize the pass-through endpoint + await initialize_pass_through_endpoints(pass_through_endpoints) + + _json_data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Who are you?"}], + } + + # Make a request to the pass-through endpoint + response = client.post( + "/v1/test-messages", json=_json_data, headers={"my-test-header": "my-test-key"} + ) + + print("JSON response: ", _json_data) + + # Assert the response + assert response.status_code == 200