# b970215e — created 12 hours ago (see commit history)
import numpy as np
import pytest

from patchcore import common


def test_calling_without_setting_index():
    """Searching with an explicit index must match searching after ``fit``.

    ``FaissNN.run`` is called once with the index passed as an argument (before
    any ``fit``), and once after ``fit(index)`` with the index argument omitted.
    Both calls must return identical distances and neighbour indices.
    """
    query = np.arange(3 * 6, dtype=np.float32).reshape(3, 6)
    index = 2 * query

    nn_search = common.FaissNN()

    distances_before_set_index, nn_indices_before_set_index = nn_search.run(
        2, query, index
    )
    nn_search.fit(index)
    distances_after_set_index, nn_indices_after_set_index = nn_search.run(2, query)

    # assert_array_equal also verifies the shapes match (np.all(a == b) would
    # silently broadcast) and reports the mismatching elements on failure.
    np.testing.assert_array_equal(
        distances_before_set_index, distances_after_set_index
    )
    np.testing.assert_array_equal(
        nn_indices_before_set_index, nn_indices_after_set_index
    )


def test_approximate_faiss():
    """``ApproximateFaissNN`` gives the same results with or without a prior ``fit``.

    The search is run once with the index passed explicitly, and once after
    ``fit(index)`` with the index argument omitted. Distances and neighbour
    indices must be identical.
    """
    query = np.ones([768, 128], dtype=np.float32)
    index = 2 * query
    # NOTE(review): every row of `index` is identical, so the returned neighbour
    # indices are pure tie-breaks; this test relies on the approximate search
    # breaking ties deterministically — confirm if it ever becomes flaky.

    nn_search = common.ApproximateFaissNN()

    distances_before_set_index, nn_indices_before_set_index = nn_search.run(
        2, query, index
    )
    nn_search.fit(index)
    distances_after_set_index, nn_indices_after_set_index = nn_search.run(2, query)

    # Shape-checked, diagnostic comparison instead of np.all(a == b).
    np.testing.assert_array_equal(
        distances_before_set_index, distances_after_set_index
    )
    np.testing.assert_array_equal(
        nn_indices_before_set_index, nn_indices_after_set_index
    )


def test_search_without_index_raises_exception():
    """``run`` with no fitted and no explicit index raises ``AttributeError``.

    Passing the index explicitly in the same call must then succeed.
    """
    n_rows, n_dims = 3, 6
    data = np.arange(n_rows * n_dims, dtype=np.float32).reshape(n_rows, n_dims)
    searcher = common.FaissNN(on_gpu=False, num_workers=4)

    # Nothing has been fitted yet and no index argument is given.
    with pytest.raises(AttributeError):
        searcher.run(2, data)

    result = searcher.run(2, data, data)
    assert result is not None


def test_read_write_index(tmpdir):
    """A saved and reloaded ``FaissNN`` index returns the same search results.

    A ``FaissNN`` is fitted on a small feature set and saved to a file under
    ``tmpdir``. A fresh ``FaissNN`` then loads that file. Both models must return
    identical distances and neighbour indices for the same query batch.
    """
    index_filename = (tmpdir / "index").strpath
    nn_model = common.FaissNN()
    features = np.arange(3 * 6, dtype=np.float32).reshape(3, 6)
    nn_model.fit(features)
    nn_model.save(index_filename)

    loaded_nn_model = common.FaissNN()
    loaded_nn_model.load(index_filename)

    query_features = np.arange(10 * 6, dtype=np.float32).reshape(10, 6)
    # Run each search once and reuse the result rather than re-querying both
    # models for every assertion.
    loaded_result = loaded_nn_model.run(2, query_features)
    assert loaded_result is not None
    original_result = nn_model.run(2, query_features)

    # Element 0 holds distances, element 1 neighbour indices (the same order the
    # other tests in this file unpack).
    np.testing.assert_array_equal(loaded_result[0], original_result[0])
    np.testing.assert_array_equal(loaded_result[1], original_result[1])


def test_average_merger_shape():
    """``AverageMerger.merge`` output has shape ``(dim0, sum of inputs' dim1)``.

    A single ``(2, 3, 4, 5)`` input yields ``(2, 3)``. Adding a ``(2, 4, 3, 5)``
    input yields ``(2, 3 + 4)``.
    """
    input_features = [
        np.arange(2 * 3 * 4 * 5).reshape([2, 3, 4, 5]),
        2 * np.arange(2 * 3 * 4 * 5).reshape([2, 4, 3, 5]),
    ]

    merger = common.AverageMerger()
    output_features = merger.merge([input_features[0]])
    # `.shape` is a plain tuple, so compare it directly; wrapping the bool in
    # np.all() was redundant and obscured the failure message.
    assert output_features.shape == (2, 3)

    merger = common.AverageMerger()
    output_features = merger.merge(input_features)
    assert output_features.shape == (2, 7)


def test_average_merger_output():
    """Merging a single all-ones feature map with ``AverageMerger`` yields only ones."""
    ones = np.ones([2, 3, 4, 5])
    merged = common.AverageMerger().merge([ones])
    assert np.all(merged == 1.0)


def test_concat_merger_shape():
    """``ConcatMerger.merge`` flattens each input and concatenates along axis 1.

    A single ``(2, 3, 4, 5)`` input yields ``(2, 3 * 4 * 5)``. Adding a
    ``(2, 4, 3, 5)`` input yields ``(2, 3 * 4 * 5 + 4 * 3 * 5)``.
    """
    input_features = [
        np.arange(2 * 3 * 4 * 5).reshape([2, 3, 4, 5]),
        2 * np.arange(2 * 3 * 4 * 5).reshape([2, 4, 3, 5]),
    ]

    merger = common.ConcatMerger()
    output_features = merger.merge([input_features[0]])
    # Compare the shape tuple directly instead of wrapping the bool in np.all().
    assert output_features.shape == (2, 3 * 4 * 5)

    merger = common.ConcatMerger()
    output_features = merger.merge(input_features)
    assert output_features.shape == (2, 3 * 4 * 5 + 4 * 3 * 5)


def test_concat_merger_output():
    """``ConcatMerger`` places each input's values in consecutive column ranges.

    With a single all-ones input, the output is all ones. With an all-ones input
    followed by an all-twos input, the first ``3 * 4 * 5`` columns are ones and
    the remaining columns are twos.
    """
    ones = np.ones([2, 3, 4, 5])
    twos = 2 * np.ones([2, 3, 4, 5])
    split = 3 * 4 * 5

    single_output = common.ConcatMerger().merge([ones])
    assert np.all(single_output == 1.0)

    pair_output = common.ConcatMerger().merge([ones, twos])
    assert np.all(pair_output[:, :split] == 1.0)
    assert np.all(pair_output[:, split:] == 2.0)