Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ async def handle_function_calls_live(
tools_dict,
agent,
streaming_lock,
function_call_event.live_session_id,
)
)
for function_call in function_calls
Expand Down Expand Up @@ -669,6 +670,7 @@ async def _execute_single_function_call_live(
tools_dict: dict[str, BaseTool],
agent: LlmAgent,
streaming_lock: asyncio.Lock,
live_session_id: Optional[str],
) -> Optional[Event]:
"""Execute a single function call for live mode with thread safety."""

Expand Down Expand Up @@ -726,7 +728,11 @@ async def _run_on_tool_error_callbacks(
)
if error_response is not None:
return __build_response_event(
tool, error_response, tool_context, invocation_context
tool,
error_response,
tool_context,
invocation_context,
live_session_id=live_session_id,
)
raise tool_error

Expand Down Expand Up @@ -823,7 +829,11 @@ async def _run_with_trace():

# Builds the function response event.
function_response_event = __build_response_event(
tool, function_response, tool_context, invocation_context
tool,
function_response,
tool_context,
invocation_context,
live_session_id=live_session_id,
)
return function_response_event

Expand Down Expand Up @@ -1108,6 +1118,8 @@ def __build_response_event(
function_result: dict[str, object],
tool_context: ToolContext,
invocation_context: InvocationContext,
*,
live_session_id: Optional[str] = None,
) -> Event:
# Specs requires the result to be a dict.
if not isinstance(function_result, dict):
Expand Down Expand Up @@ -1137,6 +1149,7 @@ def __build_response_event(
content=content,
actions=tool_context.actions,
branch=invocation_context.branch,
live_session_id=live_session_id,
)

return function_response_event
Expand Down Expand Up @@ -1199,6 +1212,7 @@ def merge_parallel_function_response_events(
branch=base_event.branch,
content=types.Content(role='user', parts=merged_parts),
actions=merged_actions, # Aggregated from all parallel events
live_session_id=base_event.live_session_id,
)

# Use the base_event as the timestamp
Expand Down
30 changes: 30 additions & 0 deletions tests/unittests/flows/llm_flows/test_functions_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,34 @@ def simple_fn(**kwargs) -> dict:
assert result_live is not None


@pytest.mark.asyncio
async def test_live_function_response_preserves_live_session_id():
def simple_fn() -> dict:
return {'result': 'ok'}

tool = FunctionTool(simple_fn)
model = testing_utils.MockModel.create(responses=[])
agent = Agent(name='test_agent', model=model, tools=[tool])
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content=''
)

function_call = types.FunctionCall(name=tool.name, args={})
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=types.Content(parts=[types.Part(function_call=function_call)]),
live_session_id='live-session-1',
)

result = await handle_function_calls_live(
invocation_context, event, {tool.name: tool}
)

assert result is not None
assert result.live_session_id == 'live-session-1'


@pytest.mark.asyncio
async def test_function_call_args_copy_behavior():
"""Test that modifying the copied args doesn't affect the original."""
Expand Down Expand Up @@ -1020,6 +1048,7 @@ def test_merge_parallel_function_response_events_preserves_other_attributes():
invocation_id=invocation_id,
author=base_author,
branch=base_branch,
live_session_id='live-session-1',
content=types.Content(
role='user', parts=[types.Part(function_response=function_response1)]
),
Expand All @@ -1041,6 +1070,7 @@ def test_merge_parallel_function_response_events_preserves_other_attributes():
assert merged_event.invocation_id == invocation_id
assert merged_event.author == base_author
assert merged_event.branch == base_branch
assert merged_event.live_session_id == 'live-session-1'

# Should contain both function responses
assert len(merged_event.content.parts) == 2
Expand Down