"""Tests for web_ui.web_ui_start module."""
from __future__ import annotations
import os
import socket
from unittest.mock import patch
import web_ui.web_ui_start as web_ui_start
from web_ui.web_ui_start import ensure_localhost_bypass_proxy, get_bind_address, main
AF_INET = socket.AF_INET
AF_INET6 = socket.AF_INET6
class TestEnsureLocalhostBypassProxy:
"""Tests for ensure_localhost_bypass_proxy function."""
def test_ensure_localhost_adds_to_empty_env(self) -> None:
"""Test adding bypass items to empty environment."""
for key in ("NO_PROXY", "no_proxy"):
os.environ.pop(key, None)
ensure_localhost_bypass_proxy()
no_proxy = os.environ.get("NO_PROXY", "")
assert "localhost" in no_proxy
assert "127.0.0.1" in no_proxy
assert "::1" in no_proxy
def test_ensure_localhost_preserves_existing(self) -> None:
"""Test that existing bypass items are preserved."""
os.environ.pop("NO_PROXY", None)
os.environ.pop("no_proxy", None)
os.environ["NO_PROXY"] = "example.com,another.org"
ensure_localhost_bypass_proxy()
no_proxy = os.environ.get("NO_PROXY", "")
assert "example.com" in no_proxy
assert "another.org" in no_proxy
assert "localhost" in no_proxy
assert "127.0.0.1" in no_proxy
def test_ensure_localhost_does_not_duplicate(self) -> None:
"""Test that items are not duplicated."""
os.environ["NO_PROXY"] = "localhost,127.0.0.1"
ensure_localhost_bypass_proxy()
no_proxy = os.environ.get("NO_PROXY", "")
parts = no_proxy.split(",")
assert parts.count("localhost") == 1
assert parts.count("127.0.0.1") == 1
def test_ensure_localhost_handles_whitespace(self) -> None:
"""Test handling of whitespace in environment variable."""
os.environ["NO_PROXY"] = "example.com , another.org"
ensure_localhost_bypass_proxy()
no_proxy = os.environ.get("NO_PROXY", "")
assert "example.com" in no_proxy
assert "localhost" in no_proxy
class TestGetBindAddress:
"""Tests for get_bind_address function."""
def test_get_bind_address_integration(self) -> None:
"""Integration test: call get_bind_address without mocks."""
result = get_bind_address()
assert result in ["127.0.0.1", "[::1]"]
def test_get_bind_address_ipv4_fallback_via_monkeypatch(self, monkeypatch) -> None:
"""Test IPv6 fallback when IPv4 fails using monkeypatch for real coverage."""
class FakeSocket:
def __init__(self, family, *args):
self.family = family
self.closed = False
def bind(self, address):
if self.family == 2:
raise OSError("IPv4 unavailable")
def close(self):
self.closed = True
import socket as socket_module
monkeypatch.setattr(socket_module, "socket", FakeSocket)
monkeypatch.setattr(web_ui_start.socket, "socket", FakeSocket)
result = get_bind_address()
assert result == "[::1]"
def test_get_bind_address_both_fail_via_monkeypatch(self, monkeypatch) -> None:
"""Test default return when both IPv4 and IPv6 fail using monkeypatch."""
class AlwaysFailSocket:
def __init__(self, family, *args):
pass
def bind(self, address):
raise OSError("Network unavailable")
def close(self):
pass
import socket as socket_module
monkeypatch.setattr(socket_module, "socket", AlwaysFailSocket)
monkeypatch.setattr(web_ui_start.socket, "socket", AlwaysFailSocket)
result = get_bind_address()
assert result == "127.0.0.1"
class TestMain:
"""Tests for main function."""
@patch("web_ui.web_ui_start.launch_app")
@patch("web_ui.web_ui_start.ensure_localhost_bypass_proxy")
@patch("web_ui.web_ui_start.get_bind_address", return_value="127.0.0.1")
@patch("sys.argv", ["web_ui_start.py"])
def test_main_default_args(self, mock_get_addr, mock_ensure, mock_launch) -> None:
"""Test main with default arguments."""
main()
mock_ensure.assert_called_once()
mock_launch.assert_called_once_with(
server_name="127.0.0.1",
server_port=2345,
share=False,
)
@patch("web_ui.web_ui_start.launch_app")
@patch("web_ui.web_ui_start.ensure_localhost_bypass_proxy")
@patch("web_ui.web_ui_start.get_bind_address", return_value="127.0.0.1")
@patch("sys.argv", ["web_ui_start.py", "--port", "8080"])
def test_main_custom_args(self, mock_get_addr, mock_ensure, mock_launch) -> None:
"""Test main with custom arguments."""
main()
mock_launch.assert_called_once_with(
server_name="127.0.0.1",
server_port=8080,
share=False,
)
@patch("web_ui.web_ui_start.launch_app")
@patch("web_ui.web_ui_start.ensure_localhost_bypass_proxy")
@patch("web_ui.web_ui_start.get_bind_address", return_value="127.0.0.1")
@patch("sys.argv", ["web_ui_start.py", "--share"])
def test_main_with_share(self, mock_get_addr, mock_ensure, mock_launch) -> None:
"""Test main with share flag."""
main()
mock_launch.assert_called_once_with(
server_name="127.0.0.1",
server_port=2345,
share=True,
)
@patch("web_ui.web_ui_start.launch_app")
@patch("web_ui.web_ui_start.ensure_localhost_bypass_proxy")
@patch("web_ui.web_ui_start.get_bind_address", return_value="127.0.0.1")
@patch.dict(os.environ, {"GRADIO_SERVER_PORT": "9000"})
@patch("sys.argv", ["web_ui_start.py"])
def test_main_with_env_vars(self, mock_get_addr, mock_ensure, mock_launch) -> None:
"""Test main with GRADIO_SERVER_PORT environment variable."""
main()
mock_launch.assert_called_once_with(
server_name="127.0.0.1",
server_port=9000,
share=False,
)
@patch("web_ui.web_ui_start.launch_app")
@patch("web_ui.web_ui_start.ensure_localhost_bypass_proxy")
@patch("web_ui.web_ui_start.get_bind_address", return_value="[::1]")
@patch("sys.argv", ["web_ui_start.py"])
def test_main_with_ipv6_detection(self, mock_get_addr, mock_ensure, mock_launch) -> None:
"""Test main when get_bind_address detects IPv6 only."""
main()
mock_launch.assert_called_once_with(
server_name="[::1]",
server_port=2345,
share=False,
)
@patch("web_ui.web_ui_start.launch_app")
@patch("web_ui.web_ui_start.ensure_localhost_bypass_proxy")
@patch("sys.argv", ["web_ui_start.py"])
def test_main_calls_ensure_bypass(self, mock_ensure, mock_launch) -> None:
"""Test that main calls ensure_localhost_bypass_proxy."""
main()
mock_ensure.assert_called_once()