From 8bb81ba1be8ffe7ebe96928d76e3ab3ab998365c Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 27 May 2026 14:07:53 -0700 Subject: [PATCH 1/3] kafka.net.transport: Always track hostname; KafkaSSLTransport ssl_check_hostname bool --- kafka/net/manager.py | 6 +++--- kafka/net/transport.py | 8 +++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/kafka/net/manager.py b/kafka/net/manager.py index f8e901402..c3b47e0e8 100644 --- a/kafka/net/manager.py +++ b/kafka/net/manager.py @@ -213,10 +213,10 @@ async def _build_transport(self, node): self.config['socket_options'], proxy_url=self.config['proxy_url']) if self.ssl_enabled: - hostname = node.host if self.config['ssl_check_hostname'] else None - transport = KafkaSSLTransport(self._net, sock, self._build_ssl_context(), hostname) + transport = KafkaSSLTransport(self._net, sock, self._build_ssl_context(), + host=node.host, ssl_check_hostname=self.config['ssl_check_hostname']) else: - transport = KafkaTCPTransport(self._net, sock) + transport = KafkaTCPTransport(self._net, sock, host=node.host) try: await transport.handshake() diff --git a/kafka/net/transport.py b/kafka/net/transport.py index e3a3c772d..cac51e031 100644 --- a/kafka/net/transport.py +++ b/kafka/net/transport.py @@ -12,9 +12,10 @@ class KafkaTCPTransport: - def __init__(self, net, sock): + def __init__(self, net, sock, host=None): self._net = net self._sock = sock + self.host = host self._closed = False self._write_buffer = deque() self._writing = False @@ -351,11 +352,12 @@ def __str__(self): class KafkaSSLTransport(KafkaTCPTransport): - def __init__(self, net, sock, ssl_context, server_hostname=None): + def __init__(self, net, sock, ssl_context, host=None, ssl_check_hostname=False): self._ssl_context = ssl_context + server_hostname = host if ssl_check_hostname else None sock = ssl_context.wrap_socket( sock, server_hostname=server_hostname, do_handshake_on_connect=False) - super().__init__(net, sock) + super().__init__(net, sock, host=host) async def handshake(self): while True: From b0ba69c15357fad5568da85609da9b58ef271e5e Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 27 May 2026 14:09:49 -0700 Subject: [PATCH 2/3] SASL: Prefer node hostname to IP address when building mechanisms --- kafka/net/connection.py | 6 +++- test/net/test_connection.py | 60 +++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/kafka/net/connection.py b/kafka/net/connection.py index 71434adb0..4cc74c0c5 100644 --- a/kafka/net/connection.py +++ b/kafka/net/connection.py @@ -417,9 +417,13 @@ async def _sasl_authenticate(self): # Step 2: SASL authentication exchange version = response.API_VERSION + # Prefer the configured hostname (stored on the transport) so that + # mechanisms like GSSAPI construct service principals against the + # user-supplied name, not whichever IP getaddrinfo handed us. + sasl_host = self.transport.host if self.transport.host else self.transport.getPeer()[0] try: mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])( - host=self.transport.getPeer()[0], **self.config) + host=sasl_host, **self.config) except Exception as exc: self.close(exc) return diff --git a/test/net/test_connection.py b/test/net/test_connection.py index a6d481f52..9ed19baac 100644 --- a/test/net/test_connection.py +++ b/test/net/test_connection.py @@ -469,3 +469,63 @@ def mock_send_request(request): net.run(conn._sasl_authenticate()) transport.abort.assert_called_once() + + def _drive_handshake_with_recording_mechanism(self, net, conn): + from kafka.protocol.sasl import SaslHandshakeRequest + api_versions = {SaslHandshakeRequest[0].API_KEY: (0, 1)} + conn.broker_version_data = BrokerVersionData(api_versions=api_versions) + handshake_response = MagicMock() + handshake_response.error_code = 0 + handshake_response.mechanisms = ['PLAIN'] + auth_response = MagicMock() + auth_response.error_code = 0 + auth_response.auth_bytes = b'' + responses = iter([handshake_response, auth_response]) + def mock_send_request(_): + f = Future() + f.success(next(responses)) + return f + conn._send_request = mock_send_request + + captured = {} + from kafka.sasl import register_sasl_mechanism + from kafka.sasl.plain import SaslMechanismPlain + + class RecordingPlain(SaslMechanismPlain): + def __init__(self, **config): + captured['host'] = config.get('host') + super().__init__(**config) + register_sasl_mechanism('PLAIN', RecordingPlain, overwrite=True) + try: + net.run(conn._sasl_authenticate()) + finally: + register_sasl_mechanism('PLAIN', SaslMechanismPlain, overwrite=True) + return captured + + def test_sasl_uses_transport_host_for_mechanism(self, net): + conn = KafkaConnection( + net, node_id='test', + security_protocol='SASL_PLAINTEXT', sasl_mechanism='PLAIN', + sasl_plain_username='user', sasl_plain_password='pass') + transport = MagicMock() + transport.host = 'kafka.example.com' + transport.getPeer.return_value = ('10.0.0.1', 9092) + conn.transport = transport + conn.initializing = True + + captured = self._drive_handshake_with_recording_mechanism(net, conn) + assert captured['host'] == 'kafka.example.com' + + def test_sasl_falls_back_to_peer_ip_when_transport_host_unset(self, net): + conn = KafkaConnection( + net, node_id='test', + security_protocol='SASL_PLAINTEXT', sasl_mechanism='PLAIN', + sasl_plain_username='user', sasl_plain_password='pass') + transport = MagicMock() + transport.host = None + transport.getPeer.return_value = ('10.0.0.1', 9092) + conn.transport = transport + conn.initializing = True + + captured = self._drive_handshake_with_recording_mechanism(net, conn) + assert captured['host'] == '10.0.0.1' From 89374125e580e559a7a25096f91186ed1604cf9d Mon Sep 17 00:00:00 2001 From: Dana Powers Date: Wed, 27 May 2026 14:41:31 -0700 Subject: [PATCH 3/3] drive-by fix for transport.host_port() when _sock is None --- kafka/net/transport.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kafka/net/transport.py b/kafka/net/transport.py index cac51e031..3b8e0606f 100644 --- a/kafka/net/transport.py +++ b/kafka/net/transport.py @@ -336,6 +336,8 @@ async def handshake(self): pass def host_port(self): + if self._sock is None: + return 'none' try: host, port = self._sock.getpeername()[0:2] except (OSError, ValueError):