feat(anthropic_adapter.py): support streaming requests for /v1/messages endpoint

Fixes https://github.com/BerriAI/litellm/issues/5011
Krrish Dholakia 2024-08-03 20:16:19 -07:00
parent b3aa722ebb
commit 5810708c71
9 changed files with 425 additions and 35 deletions
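At a high level, the change lets adapter_completion() hand streaming responses back to the /v1/messages adapter instead of only handling non-streaming ModelResponse objects. A minimal usage sketch; the anthropic_adapter export and the "adapter" dict key are assumptions not shown in these hunks, while the "id" key matches the lookup loop in the diff below:

import litellm
from litellm.adapters.anthropic_adapter import anthropic_adapter  # assumed export

# adapter_completion looks adapters up by "id" (see the loop in the diff below).
litellm.adapters = [{"id": "anthropic", "adapter": anthropic_adapter}]

# With stream=True, adapter_completion now returns an
# AdapterCompletionStreamWrapper instead of a translated BaseModel.
response = litellm.adapter_completion(
    adapter_id="anthropic",
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)
for chunk in response:
    print(chunk)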


@@ -125,7 +125,7 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
-from .types.utils import ChatCompletionMessageToolCall
+from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall
 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
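AdapterCompletionStreamWrapper is defined in .types.utils, one of the nine changed files but not shown in these hunks. A rough sketch of the shape the code below relies on; this is an assumption about the wrapper, not a copy of the commit:

# Assumed shape of AdapterCompletionStreamWrapper (litellm/types/utils.py).
class AdapterCompletionStreamWrapper:
    def __init__(self, completion_stream):
        self.completion_stream = completion_stream

    # Supports both sync and async iteration, so adapter_completion and
    # aadapter_completion can return the same wrapper type.
    def __iter__(self):
        return self

    def __aiter__(self):
        return self

    def __next__(self):
        for chunk in self.completion_stream:
            if chunk is not None:
                return chunk
        raise StopIteration

    async def __anext__(self):
        async for chunk in self.completion_stream:
            if chunk is not None:
                return chunk
        raise StopAsyncIteration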
@@ -515,7 +515,7 @@ def mock_completion(
         model_response = ModelResponse(stream=stream)
         if stream is True:
             # don't try to access stream object,
-            if kwargs.get("acompletion", False) == True:
+            if kwargs.get("acompletion", False) is True:
                 return CustomStreamWrapper(
                     completion_stream=async_mock_completion_streaming_obj(
                         model_response, mock_response=mock_response, model=model, n=n
@@ -524,13 +524,14 @@ def mock_completion(
                     custom_llm_provider="openai",
                     logging_obj=logging,
                 )
-            response = mock_completion_streaming_obj(
-                model_response,
-                mock_response=mock_response,
+            return CustomStreamWrapper(
+                completion_stream=mock_completion_streaming_obj(
+                    model_response, mock_response=mock_response, model=model, n=n
+                ),
                 model=model,
                 n=n,
+                custom_llm_provider="openai",
+                logging_obj=logging,
             )
-            return response
         if n is None:
             model_response.choices[0].message.content = mock_response  # type: ignore
         else:
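With this change the sync streaming path returns a CustomStreamWrapper, matching the async branch, instead of a bare generator. A quick way to exercise it, using litellm's documented mock_response parameter (the chunk shape is assumed to be the usual delta format):

import litellm

stream = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "ping"}],
    mock_response="pong",
    stream=True,  # previously returned a raw generator on the sync path
)
for chunk in stream:
    print(chunk.choices[0].delta.content)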
@@ -4037,7 +4038,9 @@ def text_completion(
 ###### Adapter Completion ################
-async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
+async def aadapter_completion(
+    *, adapter_id: str, **kwargs
+) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
     """
     Implemented to handle async calls for adapter_completion()
     """
@@ -4056,18 +4059,29 @@ async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseMode
         new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
-        response: ModelResponse = await acompletion(**new_kwargs)  # type: ignore
-        translated_response = translation_obj.translate_completion_output_params(
-            response=response
-        )
+        response: Union[ModelResponse, CustomStreamWrapper] = await acompletion(**new_kwargs)  # type: ignore
+        translated_response: Optional[
+            Union[BaseModel, AdapterCompletionStreamWrapper]
+        ] = None
+        if isinstance(response, ModelResponse):
+            translated_response = translation_obj.translate_completion_output_params(
+                response=response
+            )
+        if isinstance(response, CustomStreamWrapper):
+            translated_response = (
+                translation_obj.translate_completion_output_params_streaming(
+                    completion_stream=response
+                )
+            )
         return translated_response
     except Exception as e:
         raise e
 
 
-def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
+def adapter_completion(
+    *, adapter_id: str, **kwargs
+) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
     translation_obj: Optional[CustomLogger] = None
     for item in litellm.adapters:
         if item["id"] == adapter_id:
@@ -4082,11 +4096,20 @@ def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)
-    response: ModelResponse = completion(**new_kwargs)  # type: ignore
-    translated_response = translation_obj.translate_completion_output_params(
-        response=response
+    response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs)  # type: ignore
+    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
+        None
     )
+    if isinstance(response, ModelResponse):
+        translated_response = translation_obj.translate_completion_output_params(
+            response=response
+        )
+    elif isinstance(response, CustomStreamWrapper) or inspect.isgenerator(response):
+        translated_response = (
+            translation_obj.translate_completion_output_params_streaming(
+                completion_stream=response
+            )
+        )
     return translated_response
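Both code paths assume every registered adapter (a CustomLogger) implements translate_completion_output_params_streaming, the new hook this commit calls. A minimal adapter sketch under that assumption; MyAdapter and its method bodies are hypothetical, only the three hook names and their keyword arguments come from the diff:

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import AdapterCompletionStreamWrapper

class MyAdapter(CustomLogger):
    def translate_completion_input_params(self, kwargs):
        # Map the external request shape onto litellm.completion kwargs.
        return kwargs

    def translate_completion_output_params(self, response):
        # Non-streaming: map a ModelResponse onto the external response shape.
        return response

    def translate_completion_output_params_streaming(self, completion_stream):
        # Streaming: wrap the CustomStreamWrapper so callers can iterate it.
        return AdapterCompletionStreamWrapper(completion_stream=completion_stream)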