import socket
import pytest
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.utils.common_utils.url_utils import (
fix_domain_path_merge,
normalize_path,
normalize_domain,
normalize_url,
are_similar_urls,
validate_runtime_request_url,
)
# Pieces of an intentionally over-long URL. The query value alone is 8200
# characters, so the assembled URL is 8229 characters long -- beyond the
# 8192-character limit named in the length-exceeded test's expected message.
scheme, host, path = "https", "example.com", "/search"
query_key, query_value = "q", "A" * 8200
url = "".join((scheme, "://", host, path, "?", query_key, "=", query_value))
@pytest.mark.parametrize(
    ("domain", "expected"),
    [
        # Hosts without a fused "-suffix" are returned unchanged.
        ("example.com", "example.com"),
        ("sub.example.com", "sub.example.com"),
        # A "-segment" glued onto the TLD is dropped.
        ("example.com-www", "example.com"),
        ("example.net-api", "example.net"),
        ("example.org-docs", "example.org"),
        ("example.edu-courses", "example.edu"),
        ("example.gov-info", "example.gov"),
        # Hyphens inside a label are legitimate and must survive.
        ("example-api.com", "example-api.com"),
        ("example-docs.com", "example-docs.com"),
        ("example-www", "example"),
        # Expected to be left as-is.
        ("api.example.com-v1", "api.example.com-v1"),
    ],
)
def test_normalize_domain(domain, expected):
    """Compare normalize_domain's output with the expected host for each case."""
    assert normalize_domain(domain) == expected
@pytest.mark.parametrize(
    ("raw_path", "expected"),
    [
        ("/path/to/resource", "/path/to/resource"),
        ("path/to/resource", "/path/to/resource"),  # leading "/" is added
        ("/path//to//resource", "/path/to/resource"),  # repeated "/" collapse
        ("//path//to//resource", "/path/to/resource"),
        ("a", "/a"),
        ("/", "/"),
        ("", "/"),  # empty input becomes the root path
        ("/api/v1/users//123/", "/api/v1/users/123/"),  # trailing "/" is kept
    ],
)
def test_normalize_path(raw_path, expected):
    """Check normalize_path against relative, absolute and slash-heavy paths.

    The argument is named ``raw_path`` rather than ``path`` so that it does
    not shadow the module-level ``path`` constant.
    """
    assert normalize_path(raw_path) == expected
@pytest.mark.parametrize(
    ("invalid_value", "error_code", "error_msg_fragment"),
    [(url, 200021, "URL length must be less than 8192")],
)
def test_normalize_path_length_exceeded(invalid_value, error_code, error_msg_fragment):
    """Over-long input must be rejected with a CustomValueException.

    The raised exception must carry the expected ``error_code`` and include
    ``error_msg_fragment`` (a plain substring, not a regex) in its message.
    """
    with pytest.raises(CustomValueException) as excinfo:
        normalize_path(invalid_value)
    raised = excinfo.value
    assert raised.error_code == error_code
    assert error_msg_fragment in str(raised)
@pytest.mark.parametrize(
    ("raw_url", "expected"),
    [
        # Well-formed URLs pass through unchanged.
        ("https://example.com/path/to/resource", "https://example.com/path/to/resource"),
        ("http://example.com/path/to/resource", "http://example.com/path/to/resource"),
        # A path segment fused onto the TLD with "-" is moved back into the path.
        ("https://example.com-api/v1/users", "https://example.com/api/v1/users"),
        ("https://example.org-docs/v2/data", "https://example.org/docs/v2/data"),
        ("https://example.net-api/v1/endpoints", "https://example.net/api/v1/endpoints"),
        # Expected to stay as-is: the "-v1" suffix on this host is not split off.
        ("https://api.example.com-v1/users/123", "https://api.example.com-v1/users/123"),
        ("https://sub.example.com-docs/v1/data/items", "https://sub.example.com/docs/v1/data/items"),
    ],
)
def test_fix_domain_path_merge(raw_url, expected):
    """Check that fix_domain_path_merge splits "host.tld-segment" URLs correctly.

    The argument is named ``raw_url`` rather than ``url`` so that it does not
    shadow the module-level ``url`` constant used by the length test.
    """
    assert fix_domain_path_merge(raw_url) == expected
@pytest.mark.parametrize(
    ("raw_url", "expected"),
    [
        # Already-normal URLs are unchanged.
        ("https://example.com/path/to/resource", "https://example.com/path/to/resource"),
        ("http://example.org/api/v1/data", "http://example.org/api/v1/data"),
        # Domain/path merges are split apart.
        ("https://example.com-api/v1/users", "https://example.com/api/v1/users"),
        ("https://example.org-docs/v2/data", "https://example.org/docs/v2/data"),
        # Repeated slashes in the path collapse.
        ("https://example.com//path//to//resource", "https://example.com/path/to/resource"),
        ("https://example.com/path//to//resource", "https://example.com/path/to/resource"),
        # Both fixes at once.
        ("https://example.com-api//path//to//resource", "https://example.com/api/path/to/resource"),
        # NOTE(review): a scheme-less input is expected to come back as
        # "https:///..." -- an empty host with the original text moved into
        # the path. This pins current behaviour; confirm it is intended.
        ("example.com/path/to/resource", "https:///example.com/path/to/resource"),
        # Query strings and fragments are preserved.
        ("https://example.com/api?query=test", "https://example.com/api?query=test"),
        ("https://example.com-api//api?query=test", "https://example.com/api/api?query=test"),
        ("https://example.com/page#section", "https://example.com/page#section"),
        ("https://example.com-api//page#section", "https://example.com/api/page#section"),
    ],
)
def test_normalize_url(raw_url, expected):
    """Check normalize_url's combined domain, path, query and fragment handling.

    The argument is named ``raw_url`` rather than ``url`` so that it does not
    shadow the module-level ``url`` constant used by the length test.
    """
    assert normalize_url(raw_url) == expected
@pytest.mark.parametrize(
    ("invalid_url", "expected"),
    [
        ("not-a-valid-url", "https:///not-a-valid-url"),
        ("://missing-protocol.com", "https:///:/missing-protocol.com"),
        ("https://", "https:///"),
        ("", "https:///"),
    ],
)
def test_normalize_url_invalid_urls(invalid_url, expected):
    """normalize_url returns a string, rather than raising, for malformed input.

    Every expected value is an "https:///..." string with an empty host; these
    cases pin that current behaviour.
    """
    assert normalize_url(invalid_url) == expected
@pytest.mark.parametrize(
    ("url1", "url2", "expected"),
    [
        # Identical URLs.
        ("https://example.com/path/to/resource", "https://example.com/path/to/resource", True),
        # Equal once domain/path merge and double slashes are normalized.
        ("https://example.com-api//path//to//resource", "https://example.com/path/to/resource", True),
        # Scheme present vs. absent.
        ("http://example.org/api/v1/data", "example.org/api/v1/data", True),
        # Differing query, fragment, or trailing slash still count as similar.
        ("https://example.com/api?query=test", "https://example.com/api?query=other", True),
        ("https://example.com/page#section1", "https://example.com/page#section2", True),
        ("https://example.com/path/to/resource", "https://example.com/path/to/resource/", True),
        # A different host or a different path is not similar.
        ("https://example.com/path/to/resource", "https://different.com/path/to/resource", False),
        ("https://example.com/path/to/resource", "https://example.com/different/path", False),
    ],
)
def test_are_similar_urls(url1, url2, expected):
    """Check are_similar_urls with its default threshold."""
    assert are_similar_urls(url1, url2) == expected
@pytest.mark.parametrize(
    ("url1", "url2", "threshold", "expected"),
    [
        # Same URLs apart from the last path segment: similar at a loose
        # threshold, not at a strict one.
        ("https://example.com/path/to/resource", "https://example.com/path/to/other", 0.5, True),
        ("https://example.com/path/to/resource", "https://example.com/path/to/other", 0.9, False),
        # Identical URLs pass even the strictest threshold.
        ("https://example.com/path/to/resource", "https://example.com/path/to/resource", 1.0, True),
    ],
)
def test_are_similar_urls_threshold(url1, url2, threshold, expected):
    """Check that the third (threshold) argument changes are_similar_urls' verdict."""
    assert are_similar_urls(url1, url2, threshold) == expected
@pytest.mark.parametrize(
    ("url1", "url2"),
    [
        ("not-a-valid-url", "https://example.com/path"),
        ("https://example.com/path", "not-a-valid-url"),
        ("not-a-valid-url", "another-invalid-url"),
    ],
)
def test_are_similar_urls_invalid_urls(url1, url2):
    """If either URL is malformed, the result must be exactly ``False`` (not merely falsy)."""
    similar = are_similar_urls(url1, url2)
    assert similar is False
def _mock_getaddrinfo(*addresses):
    """Return a fake ``socket.getaddrinfo`` that resolves any host to *addresses*.

    Each address produces one ``(family, type, proto, canonname, sockaddr)``
    5-tuple, in the order given. Addresses containing ``":"`` are treated as
    IPv6; all others as IPv4.

    The fake accepts the same parameters as the real function
    (``host, port, family=0, type=0, proto=0, flags=0``), positionally or by
    keyword. The code under test can therefore call it exactly as it calls
    the real one; the filtering arguments are accepted but ignored. IPv4
    sockaddrs are ``(address, port)``. IPv6 sockaddrs are the 4-tuple
    ``(address, port, flowinfo, scope_id)`` that the real function returns,
    so code that unpacks them the wrong way fails here as it would in
    production.
    """

    # Parameter names mirror socket.getaddrinfo (including ``type``) so that
    # keyword calls are accepted; only ``port`` is actually used.
    def resolver(host, port, family=0, type=0, proto=0, flags=0):
        results = []
        for address in addresses:
            if ":" in address:
                family_, sockaddr = socket.AF_INET6, (address, port, 0, 0)
            else:
                family_, sockaddr = socket.AF_INET, (address, port)
            results.append((family_, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", sockaddr))
        return results

    return resolver
def test_validate_runtime_request_url_blocks_dns_to_non_public_ip(monkeypatch):
    """A hostname that resolves to a non-public (link-local) IP must be rejected.

    DNS is faked so that ``metadata.attacker.test`` resolves to
    169.254.169.254. The test also records the hostnames looked up. This
    confirms the rejection came from the patched resolver, not from the
    validator failing earlier for an unrelated reason, which would make the
    ``pytest.raises`` check pass vacuously.
    """
    fake_resolver = _mock_getaddrinfo("169.254.169.254")
    looked_up = []

    def recording_resolver(host, *args, **kwargs):
        looked_up.append(host)
        return fake_resolver(host, *args, **kwargs)

    monkeypatch.setattr(socket, "getaddrinfo", recording_resolver)
    with pytest.raises(CustomValueException):
        validate_runtime_request_url("http://metadata.attacker.test/latest/meta-data/")
    # NOTE(review): assumes the validator resolves via socket.getaddrinfo,
    # passing the bare hostname; adjust if it resolves differently.
    assert "metadata.attacker.test" in looked_up