Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
feat(anthropic_adapter.py): support streaming requests for /v1/messages endpoint

Fixes https://github.com/BerriAI/litellm/issues/5011
This commit is contained in:
parent 39a98a2882
commit ac6c39c283

9 changed files with 425 additions and 35 deletions
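In practice, the change lets adapter calls stream. A minimal usage sketch, not taken from this diff: the adapter id "anthropic", the model name, and top-level access via litellm.adapter_completion are assumptions.

    import litellm

    # Assumed: an adapter registered under the id "anthropic" (per the commit
    # title) translates /v1/messages-shaped requests to and from litellm's
    # native completion format; id and model below are placeholders.
    stream = litellm.adapter_completion(
        adapter_id="anthropic",
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,  # the newly supported path
    )
    for chunk in stream:  # an AdapterCompletionStreamWrapper of translated chunks
        print(chunk)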
@@ -125,7 +125,7 @@ from .llms.vertex_ai_partner import VertexAIPartnerModels
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
-from .types.utils import ChatCompletionMessageToolCall
+from .types.utils import AdapterCompletionStreamWrapper, ChatCompletionMessageToolCall

 encoding = tiktoken.get_encoding("cl100k_base")
 from litellm.utils import (
@@ -515,7 +515,7 @@ def mock_completion(
         model_response = ModelResponse(stream=stream)
         if stream is True:
             # don't try to access stream object,
-            if kwargs.get("acompletion", False) == True:
+            if kwargs.get("acompletion", False) is True:
                 return CustomStreamWrapper(
                     completion_stream=async_mock_completion_streaming_obj(
                         model_response, mock_response=mock_response, model=model, n=n
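The "== True" to "is True" change in this hunk is a behavioral tightening, not just style: "==" goes through __eq__, so any truthy value equal to True (such as the int 1) passes, while "is" requires the True singleton itself. A quick illustration:

    flag = 1                 # truthy, but not the bool singleton
    print(flag == True)      # True  -- __eq__ coerces, so 1 compares equal to True
    print(flag is True)      # False -- identity test against the True singleton
    print(True is True)      # True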
@@ -524,13 +524,14 @@ def mock_completion(
                     custom_llm_provider="openai",
                     logging_obj=logging,
                 )
-            response = mock_completion_streaming_obj(
-                model_response,
-                mock_response=mock_response,
+            return CustomStreamWrapper(
+                completion_stream=mock_completion_streaming_obj(
+                    model_response, mock_response=mock_response, model=model, n=n
+                ),
                 model=model,
                 n=n,
+                custom_llm_provider="openai",
+                logging_obj=logging,
             )
-            return response
         if n is None:
             model_response.choices[0].message.content = mock_response # type: ignore
         else:
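The sync mock path previously returned the raw generator from mock_completion_streaming_obj; it now mirrors the async branch and wraps it in CustomStreamWrapper, so mocked streams carry the same metadata and iteration interface as real ones. A sketch of the effect, assuming mock_response routes litellm.completion through mock_completion (that routing sits outside this hunk):

    import litellm

    stream = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        mock_response="This is a mocked reply",  # assumed routing to mock_completion
        stream=True,
    )
    # Both mocked and real streams are now CustomStreamWrapper instances.
    for chunk in stream:
        print(chunk.choices[0].delta)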
@@ -4037,7 +4038,9 @@ def text_completion(
 ###### Adapter Completion ################


-async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
+async def aadapter_completion(
+    *, adapter_id: str, **kwargs
+) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
     """
     Implemented to handle async calls for adapter_completion()
     """
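The return type widens here from Optional[BaseModel] to Optional[Union[BaseModel, AdapterCompletionStreamWrapper]], so async callers must narrow the result before use. A minimal consumer sketch, assuming the wrapper supports async iteration; the adapter id and model are placeholders:

    import asyncio

    from litellm import aadapter_completion
    from litellm.types.utils import AdapterCompletionStreamWrapper

    async def main() -> None:
        result = await aadapter_completion(
            adapter_id="anthropic",  # assumed registry id
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
            stream=True,
        )
        if isinstance(result, AdapterCompletionStreamWrapper):
            async for chunk in result:  # assumed: async-iterable wrapper
                print(chunk)
        elif result is not None:
            print(result)  # non-streaming: a translated BaseModel

    asyncio.run(main())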
@@ -4056,18 +4059,29 @@ async def aadapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseMode

         new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

-        response: ModelResponse = await acompletion(**new_kwargs) # type: ignore
-
-        translated_response = translation_obj.translate_completion_output_params(
-            response=response
-        )
+        response: Union[ModelResponse, CustomStreamWrapper] = await acompletion(**new_kwargs) # type: ignore
+        translated_response: Optional[
+            Union[BaseModel, AdapterCompletionStreamWrapper]
+        ] = None
+        if isinstance(response, ModelResponse):
+            translated_response = translation_obj.translate_completion_output_params(
+                response=response
+            )
+        if isinstance(response, CustomStreamWrapper):
+            translated_response = (
+                translation_obj.translate_completion_output_params_streaming(
+                    completion_stream=response
+                )
+            )

         return translated_response
     except Exception as e:
         raise e


-def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:
+def adapter_completion(
+    *, adapter_id: str, **kwargs
+) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]:
     translation_obj: Optional[CustomLogger] = None
     for item in litellm.adapters:
         if item["id"] == adapter_id:
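The new body branches on the concrete response type instead of assuming a ModelResponse: complete responses go through translate_completion_output_params, streams through translate_completion_output_params_streaming. The shape of that dispatch in isolation, with toy classes standing in for ModelResponse and CustomStreamWrapper:

    from typing import Optional, Union

    class Complete:               # stand-in for ModelResponse
        text = "done"

    class Stream:                 # stand-in for CustomStreamWrapper
        def __iter__(self):
            return iter(["d", "o", "n", "e"])

    def translate(response: Union[Complete, Stream]) -> Optional[str]:
        translated: Optional[str] = None
        if isinstance(response, Complete):   # whole response: translate once
            translated = response.text
        if isinstance(response, Stream):     # stream: translate chunk-wise
            translated = "".join(response)
        return translated

    print(translate(Complete()))  # done
    print(translate(Stream()))    # done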
@@ -4082,11 +4096,20 @@ def adapter_completion(*, adapter_id: str, **kwargs) -> Optional[BaseModel]:

     new_kwargs = translation_obj.translate_completion_input_params(kwargs=kwargs)

-    response: ModelResponse = completion(**new_kwargs) # type: ignore
-
-    translated_response = translation_obj.translate_completion_output_params(
-        response=response
+    response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore
+    translated_response: Optional[Union[BaseModel, AdapterCompletionStreamWrapper]] = (
+        None
     )
+    if isinstance(response, ModelResponse):
+        translated_response = translation_obj.translate_completion_output_params(
+            response=response
+        )
+    elif isinstance(response, CustomStreamWrapper) or inspect.isgenerator(response):
+        translated_response = (
+            translation_obj.translate_completion_output_params_streaming(
+                completion_stream=response
+            )
+        )

     return translated_response

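The sync variant mirrors the async one but uses elif and additionally accepts a plain generator via inspect.isgenerator, since a sync streaming call may hand back a bare generator rather than a CustomStreamWrapper. A consumer-side sketch, with the adapter id assumed as before and the wrapper assumed to be sync-iterable:

    from litellm import adapter_completion
    from litellm.types.utils import AdapterCompletionStreamWrapper

    result = adapter_completion(
        adapter_id="anthropic",  # assumed registry id
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    if isinstance(result, AdapterCompletionStreamWrapper):
        for chunk in result:      # assumed: the wrapper iterates synchronously
            print(chunk)
    elif result is not None:
        print(result)             # non-streaming: a translated BaseModel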