Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-23 04:12:25 +00:00)
more fixes and record from the server since server enables more tests
parent 6ebc93de81
commit 9b3a860beb
29 changed files with 2584 additions and 762 deletions

@@ -200,12 +200,12 @@ class ResponseStorage:
         return cast(dict[str, Any], data)


-async def _patched_inference_method(original_method, self, client_type, method_name=None, **kwargs):
+async def _patched_inference_method(original_method, self, client_type, method_name=None, *args, **kwargs):
     global _current_mode, _current_storage

     if _current_mode == "live" or _current_storage is None:
         # Normal operation
-        return await original_method(self, **kwargs)
+        return await original_method(self, *args, **kwargs)

     # Get base URL and endpoint based on client type
     if client_type == "openai":
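
The heart of the change: the patched wrappers previously accepted only `**kwargs`, so any caller that passed arguments positionally would hit a TypeError before reaching the real client method. Threading `*args` through fixes that. A minimal runnable sketch of the forwarding pattern, using a hypothetical `FakeClient` rather than the real SDK classes:

```python
import asyncio

class FakeClient:
    # Stand-in for an SDK client method that callers may invoke
    # positionally or by keyword.
    async def create(self, model, prompt=None):
        return f"{model}:{prompt}"

_original_create = FakeClient.create

async def patched_create(self, *args, **kwargs):
    # Forward positionals and keywords unchanged, as the wrappers in
    # this commit now do; a **kwargs-only wrapper would raise
    # TypeError on the positional call below.
    return await _original_create(self, *args, **kwargs)

FakeClient.create = patched_create

async def main():
    client = FakeClient()
    print(await client.create("llama3", prompt="hi"))  # keyword call
    print(await client.create("llama3", "hi"))         # positional call

asyncio.run(main())
```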

@@ -284,7 +284,7 @@ async def _patched_inference_method(original_method, self, client_type, method_name=None, **kwargs):
         )

     elif _current_mode == "record":
-        response = await original_method(self, **kwargs)
+        response = await original_method(self, *args, **kwargs)

         request_data = {
             "method": method,
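
In record mode the call is both executed and serialized into `request_data` for later replay. Once positional arguments are let through, it can help to normalize them so the same logical call records identically either way; the helper below is a hypothetical illustration of that idea, not code from this commit:

```python
import inspect

def normalize_call(func, self_obj, *args, **kwargs):
    # Hypothetical helper: fold positional arguments into keywords so
    # a recorded request payload is identical for either call style.
    bound = inspect.signature(func).bind(self_obj, *args, **kwargs)
    bound.apply_defaults()
    params = dict(bound.arguments)
    params.pop("self", None)
    return params

async def chat(self, model, messages=None, stream=False):
    ...

# Positional and keyword invocations normalize to the same payload:
print(normalize_call(chat, object(), "llama3", [{"role": "user", "content": "hi"}]))
print(normalize_call(chat, object(), model="llama3", messages=[{"role": "user", "content": "hi"}]))
# both: {'model': 'llama3', 'messages': [...], 'stream': False}
```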

@@ -321,7 +321,7 @@ async def _patched_inference_method(original_method, self, client_type, method_name=None, **kwargs):
         return response

     else:
-        return await original_method(self, **kwargs)
+        raise AssertionError(f"Invalid mode: {_current_mode}")


 def patch_inference_clients():
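
Also notable: the fallback branch no longer silently forwards to the live backend. With only three legal modes, an unrecognized value is a programming error and now fails loudly. A compact sketch of the resulting dispatch shape, with the surrounding machinery elided:

```python
def dispatch(mode: str) -> str:
    # Sketch of the tightened control flow: live, record, and replay
    # are handled explicitly; anything else is a bug, not a fallback.
    if mode == "live":
        return "call the real backend"
    elif mode == "record":
        return "call the real backend and persist the response"
    elif mode == "replay":
        return "serve the persisted response"
    raise AssertionError(f"Invalid mode: {mode}")

print(dispatch("replay"))
```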

@@ -347,14 +347,16 @@ def patch_inference_clients():
     }

     # Create patched methods for OpenAI client
-    async def patched_chat_completions_create(self, **kwargs):
-        return await _patched_inference_method(_original_methods["chat_completions_create"], self, "openai", **kwargs)
+    async def patched_chat_completions_create(self, *args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["chat_completions_create"], self, "openai", *args, **kwargs
+        )

-    async def patched_completions_create(self, **kwargs):
-        return await _patched_inference_method(_original_methods["completions_create"], self, "openai", **kwargs)
+    async def patched_completions_create(self, *args, **kwargs):
+        return await _patched_inference_method(_original_methods["completions_create"], self, "openai", *args, **kwargs)

-    async def patched_embeddings_create(self, **kwargs):
-        return await _patched_inference_method(_original_methods["embeddings_create"], self, "openai", **kwargs)
+    async def patched_embeddings_create(self, *args, **kwargs):
+        return await _patched_inference_method(_original_methods["embeddings_create"], self, "openai", *args, **kwargs)

     # Apply OpenAI patches
     AsyncChatCompletions.create = patched_chat_completions_create
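
These wrappers are plain `async def` functions assigned straight onto the SDK classes (`AsyncChatCompletions.create` and friends), so every instance, existing or future, picks them up through normal attribute lookup. A self-contained sketch of that mechanism and its reversal, using a hypothetical `Widget` class:

```python
import asyncio

class Widget:
    async def ping(self):
        return "real"

# Save the original unbound function, mirroring _original_methods.
_saved = {"ping": Widget.ping}

async def patched_ping(self, *args, **kwargs):
    # A function assigned to a class attribute becomes a method, so
    # `self` binds exactly as it does for the patched SDK classes.
    return "patched"

Widget.ping = patched_ping
print(asyncio.run(Widget().ping()))  # patched

Widget.ping = _saved["ping"]         # restore, as unpatch does
print(asyncio.run(Widget().ping()))  # real
```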

@@ -362,27 +364,33 @@ def patch_inference_clients():
     AsyncEmbeddings.create = patched_embeddings_create

     # Create patched methods for Ollama client
-    async def patched_ollama_generate(self, **kwargs):
+    async def patched_ollama_generate(self, *args, **kwargs):
         return await _patched_inference_method(
-            _original_methods["ollama_generate"], self, "ollama", "generate", **kwargs
+            _original_methods["ollama_generate"], self, "ollama", "generate", *args, **kwargs
         )

-    async def patched_ollama_chat(self, **kwargs):
-        return await _patched_inference_method(_original_methods["ollama_chat"], self, "ollama", "chat", **kwargs)
+    async def patched_ollama_chat(self, *args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["ollama_chat"], self, "ollama", "chat", *args, **kwargs
+        )

-    async def patched_ollama_embed(self, **kwargs):
-        return await _patched_inference_method(_original_methods["ollama_embed"], self, "ollama", "embed", **kwargs)
+    async def patched_ollama_embed(self, *args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["ollama_embed"], self, "ollama", "embed", *args, **kwargs
+        )

-    async def patched_ollama_ps(self, **kwargs):
-        logger.info("replay mode: ollama.ps() reporting success")
-        return []
+    async def patched_ollama_ps(self, *args, **kwargs):
+        return await _patched_inference_method(_original_methods["ollama_ps"], self, "ollama", "ps", *args, **kwargs)

     async def patched_ollama_pull(self, *args, **kwargs):
-        logger.info("replay mode: ollama.pull() not actually pulling the model")
-        return None
+        return await _patched_inference_method(
+            _original_methods["ollama_pull"], self, "ollama", "pull", *args, **kwargs
+        )

-    async def patched_ollama_list(self, **kwargs):
-        return await _patched_inference_method(_original_methods["ollama_list"], self, "ollama", "list", **kwargs)
+    async def patched_ollama_list(self, *args, **kwargs):
+        return await _patched_inference_method(
+            _original_methods["ollama_list"], self, "ollama", "list", *args, **kwargs
+        )

     # Apply Ollama patches
     OllamaAsyncClient.generate = patched_ollama_generate
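
Behaviorally, the interesting part of this hunk is that `ps`, `pull`, and `list` stop being hard-coded replay stubs and join the same record/replay path as the inference calls, keyed by method name. The six wrappers are now near-identical, so a factory could generate them; the sketch below shows that alternative shape with stub globals so it runs in isolation (this is a refactoring idea, not what the commit does):

```python
import asyncio

# Stubs standing in for the module's real globals, so the sketch is
# runnable on its own.
_original_methods = {}

async def _patched_inference_method(original_method, self, client_type, method_name, *args, **kwargs):
    return (client_type, method_name, args, kwargs)

class OllamaAsyncClient:
    async def chat(self, **kwargs):
        return "real chat"
    async def list(self, **kwargs):
        return "real list"

def make_wrapper(key: str, method_name: str):
    # One factory instead of six near-identical handwritten wrappers.
    async def wrapper(self, *args, **kwargs):
        return await _patched_inference_method(
            _original_methods[key], self, "ollama", method_name, *args, **kwargs
        )
    return wrapper

for name in ("chat", "list"):
    _original_methods[f"ollama_{name}"] = getattr(OllamaAsyncClient, name)
    setattr(OllamaAsyncClient, name, make_wrapper(f"ollama_{name}", name))

print(asyncio.run(OllamaAsyncClient().chat(model="llama3")))
# ('ollama', 'chat', (), {'model': 'llama3'})
```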

@@ -416,6 +424,7 @@ def unpatch_inference_clients():
     OllamaAsyncClient.chat = _original_methods["ollama_chat"]
     OllamaAsyncClient.embed = _original_methods["ollama_embed"]
     OllamaAsyncClient.ps = _original_methods["ollama_ps"]
+    OllamaAsyncClient.pull = _original_methods["ollama_pull"]
     OllamaAsyncClient.list = _original_methods["ollama_list"]

     _original_methods.clear()
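
The one-line addition here is the other half of the `pull` change: everything `patch_inference_clients` saves must be restored in `unpatch_inference_clients`, or the patched method leaks into later tests. A small convenience wrapper (hypothetical, assuming the two functions from this diff are in scope) makes the pairing exception-safe:

```python
import contextlib

@contextlib.contextmanager
def patched_clients():
    # Guarantees unpatching even if the body raises; assumes
    # patch_inference_clients / unpatch_inference_clients from the
    # module patched in this commit are importable.
    patch_inference_clients()
    try:
        yield
    finally:
        unpatch_inference_clients()

# Usage:
# with patched_clients():
#     run_tests_against_recorded_clients()
```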