From 8b9748a5b5d2cc2c0c959458115717d71a438054 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Fri, 15 May 2026 02:58:20 +0800 Subject: [PATCH] fix: preserve live session id on tool responses --- src/google/adk/flows/llm_flows/functions.py | 18 +++++++++-- .../flows/llm_flows/test_functions_simple.py | 30 +++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 3c61c15ff3..5897ecbc89 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -625,6 +625,7 @@ async def handle_function_calls_live( tools_dict, agent, streaming_lock, + function_call_event.live_session_id, ) ) for function_call in function_calls @@ -669,6 +670,7 @@ async def _execute_single_function_call_live( tools_dict: dict[str, BaseTool], agent: LlmAgent, streaming_lock: asyncio.Lock, + live_session_id: Optional[str], ) -> Optional[Event]: """Execute a single function call for live mode with thread safety.""" @@ -726,7 +728,11 @@ async def _run_on_tool_error_callbacks( ) if error_response is not None: return __build_response_event( - tool, error_response, tool_context, invocation_context + tool, + error_response, + tool_context, + invocation_context, + live_session_id=live_session_id, ) raise tool_error @@ -823,7 +829,11 @@ async def _run_with_trace(): # Builds the function response event. function_response_event = __build_response_event( - tool, function_response, tool_context, invocation_context + tool, + function_response, + tool_context, + invocation_context, + live_session_id=live_session_id, ) return function_response_event @@ -1108,6 +1118,8 @@ def __build_response_event( function_result: dict[str, object], tool_context: ToolContext, invocation_context: InvocationContext, + *, + live_session_id: Optional[str] = None, ) -> Event: # Specs requires the result to be a dict. if not isinstance(function_result, dict): @@ -1137,6 +1149,7 @@ def __build_response_event( content=content, actions=tool_context.actions, branch=invocation_context.branch, + live_session_id=live_session_id, ) return function_response_event @@ -1199,6 +1212,7 @@ def merge_parallel_function_response_events( branch=base_event.branch, content=types.Content(role='user', parts=merged_parts), actions=merged_actions, # Aggregated from all parallel events + live_session_id=base_event.live_session_id, ) # Use the base_event as the timestamp diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index f63cefeb45..55d0cd3c69 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -502,6 +502,34 @@ def simple_fn(**kwargs) -> dict: assert result_live is not None +@pytest.mark.asyncio +async def test_live_function_response_preserves_live_session_id(): + def simple_fn() -> dict: + return {'result': 'ok'} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent(name='test_agent', model=model, tools=[tool]) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + function_call = types.FunctionCall(name=tool.name, args={}) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=types.Content(parts=[types.Part(function_call=function_call)]), + live_session_id='live-session-1', + ) + + result = await handle_function_calls_live( + invocation_context, event, {tool.name: tool} + ) + + assert result is not None + assert result.live_session_id == 'live-session-1' + + @pytest.mark.asyncio async def test_function_call_args_copy_behavior(): """Test that modifying the copied args doesn't affect the original.""" @@ -1020,6 +1048,7 @@ def test_merge_parallel_function_response_events_preserves_other_attributes(): invocation_id=invocation_id, author=base_author, branch=base_branch, + live_session_id='live-session-1', content=types.Content( role='user', parts=[types.Part(function_response=function_response1)] ), @@ -1041,6 +1070,7 @@ def test_merge_parallel_function_response_events_preserves_other_attributes(): assert merged_event.invocation_id == invocation_id assert merged_event.author == base_author assert merged_event.branch == base_branch + assert merged_event.live_session_id == 'live-session-1' # Should contain both function responses assert len(merged_event.content.parts) == 2