Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,14 +507,19 @@ async def run_live(
attempt += 1
if not llm_request.live_connect_config:
llm_request.live_connect_config = types.LiveConnectConfig()
if not llm_request.live_connect_config.session_resumption:
session_resumption = (
llm_request.live_connect_config.session_resumption
)
if not session_resumption:
session_resumption = types.SessionResumptionConfig()
llm_request.live_connect_config.session_resumption = (
types.SessionResumptionConfig()
session_resumption
)
llm_request.live_connect_config.session_resumption.handle = (
session_resumption.handle = (
invocation_context.live_session_resumption_handle
)
llm_request.live_connect_config.session_resumption.transparent = True
if session_resumption.transparent is None:
session_resumption.transparent = True

logger.info(
'Establishing live connection for agent: %s',
Expand Down
72 changes: 72 additions & 0 deletions tests/unittests/flows/llm_flows/test_base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,78 @@ async def mock_receive_2():
assert invocation_context.live_session_resumption_handle == 'test_handle'


@pytest.mark.asyncio
async def test_run_live_reconnect_preserves_nontransparent_resumption():
"""Test that reconnect does not force transparent resumption."""
from google.adk.agents.live_request_queue import LiveRequestQueue
from websockets.exceptions import ConnectionClosed

real_model = Gemini()
mock_connection = mock.AsyncMock()

async def mock_receive():
yield LlmResponse(
live_session_resumption_update=types.LiveServerSessionResumptionUpdate(
new_handle='test_handle'
)
)
raise ConnectionClosed(None, None)

mock_connection.receive = mock.Mock(side_effect=mock_receive)

agent = Agent(name='test_agent', model=real_model)
invocation_context = await testing_utils.create_invocation_context(
agent=agent
)
invocation_context.live_request_queue = LiveRequestQueue()

flow = BaseLlmFlowForTesting()

async def mock_preprocess(ctx, req):
req.live_connect_config.session_resumption = types.SessionResumptionConfig(
transparent=False
)
if False:
yield

with mock.patch.object(
flow, '_preprocess_async', side_effect=mock_preprocess
):
with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock):
mock_connection_2 = mock.AsyncMock()

class StopError(Exception):
pass

async def mock_receive_2():
yield LlmResponse(
content=types.Content(parts=[types.Part.from_text(text='hi')])
)
raise StopError('stop')

mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2)

mock_aenter = mock.AsyncMock()
mock_aenter.side_effect = [mock_connection, mock_connection_2]

with mock.patch(
'google.adk.models.google_llm.Gemini.connect'
) as mock_connect:
mock_connect.return_value.__aenter__ = mock_aenter

try:
async for _ in flow.run_live(invocation_context):
pass
except StopError:
pass

reconnect_request = mock_connect.call_args_list[1].args[0]
assert (
reconnect_request.live_connect_config.session_resumption.transparent
is False
)


@pytest.mark.asyncio
async def test_run_live_skips_send_history_on_resumption():
"""Test that run_live skips send_history when resuming a session."""
Expand Down
Loading