/*
* Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// SWIG module declaration: the generated Python extension is named
// "ascendfaiss".
%module ascendfaiss;
// SWIG's standard mappings for the <stdint.h> fixed-width integer types.
%include <stdint.i>
// Seen by SWIG's parser only (this is outside any %{ %} section): treat
// size_t as uint64_t so it is converted with the uint64_t mapping above.
typedef uint64_t size_t;
// Hide the compiler-specific __restrict qualifier from SWIG's parser.
#define __restrict
/*******************************************************************
* Copied verbatim to the generated wrapper. Contains the C++ headers the
* wrapper code needs and, for Python builds, the numpy C-API header.
*******************************************************************/
%{
#include <stdint.h>
#ifdef SWIGPYTHON
// NOTE(review): presumably drops a popcount64 macro that would clash with a
// symbol of the same name in the faiss headers -- confirm.
#undef popcount64
// numpy.i convention for the translation unit that calls import_array()
// (done in the %init section of this file).
#define SWIG_FILE_WITH_INIT
// Compile against the numpy C API without the APIs deprecated as of 1.7.
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/arrayobject.h>
#endif
// Ascend index classes, the cloner, and faiss auxiliary structures that the
// wrapped declarations refer to.
#include <faiss/ascend/AscendCloner.h>
#include <faiss/ascend/AscendIndexFlat.h>
#include <faiss/ascend/AscendIndexIVFSQ.h>
#include <faiss/ascend/AscendIndexInt8Flat.h>
#include <faiss/ascend/AscendIndexSQ.h>
#include <faiss/ascend/AscendIndexIVFSP.h>
#include <faiss/ascend/custom/AscendIndexIVFSQT.h>
#include <faiss/ascend/AscendIndexBinaryFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/ascend/utils/Version.h>
%}
/********************************************************
* GIL manipulation and exception handling
********************************************************/
#ifdef SWIGPYTHON
// %catches(faiss::FaissException);
// Python-specific: every wrapped call runs with the GIL released, so that
// long-running index operations do not block other Python threads.
// Py_BEGIN_ALLOW_THREADS declares the thread-state variable _save. Every
// handler must re-acquire the GIL with PyEval_RestoreThread(_save) before it
// touches Python state (PyErr_*). It then jumps to the wrapper's fail label
// via SWIG_fail. No C++ exception may escape this block: the caller is C code
// in the Python interpreter, so the final catch-all converts anything else
// into a Python RuntimeError instead of unwinding through C frames.
%exception {
    Py_BEGIN_ALLOW_THREADS
    try {
        $action
    } catch(const faiss::FaissException & e) {
        PyEval_RestoreThread(_save);
        if (PyErr_Occurred()) {
            // some previous code already set the error type.
        } else {
            PyErr_SetString(PyExc_RuntimeError, e.what());
        }
        SWIG_fail;
    } catch(const std::bad_alloc &) {
        PyEval_RestoreThread(_save);
        PyErr_SetString(PyExc_MemoryError, "std::bad_alloc");
        SWIG_fail;
    } catch(const std::exception& ex) {
        PyEval_RestoreThread(_save);
        std::string what = std::string("C++ exception ") + ex.what();
        PyErr_SetString(PyExc_RuntimeError, what.c_str());
        SWIG_fail;
    } catch(...) {
        // Exception type not derived from std::exception: report it rather
        // than letting it propagate into the interpreter.
        PyEval_RestoreThread(_save);
        PyErr_SetString(PyExc_RuntimeError, "unknown C++ exception");
        SWIG_fail;
    }
    Py_END_ALLOW_THREADS
}
#endif
/*******************************************************************
* Types of vectors we want to manipulate at the scripting language
* level.
*******************************************************************/
// Simplified interface for std::vector. SWIG is shown this reduced
// declaration (SWIG's <std_vector.i> is not included), so only the members
// listed here are exposed on the %template instantiations that follow.
namespace std {
template<class T>
class vector {
public:
vector();
void push_back(T);
void clear();
T * data();
size_t size();
T at (size_t n) const;
T & operator [] (size_t n);
void resize (size_t n);
void swap (vector<T> & other);
};
};
// SWIG library support for std::string, std::pair, std::map and
// std::shared_ptr.
%include <std_string.i>
%include <std_pair.i>
%include <std_map.i>
%include <std_shared_ptr.i>
// primitive array types
// UInt64Vector / LongVector use unsigned long / long, the same C types that
// swig_ptr uses for NPY_UINT64 / NPY_INT64 arrays.
%template(FloatVector) std::vector<float>;
%template(DoubleVector) std::vector<double>;
%template(ByteVector) std::vector<uint8_t>;
%template(CharVector) std::vector<char>;
%template(UInt64Vector) std::vector<unsigned long>;
%template(LongVector) std::vector<long>;
%template(IntVector) std::vector<int>;
// NOTE(review): faiss::idx_t is declared by headers %included further down;
// verify that SWIG resolves this instantiation as intended.
%template(FaissIdxVector) std::vector<faiss::idx_t>;
// Nested vectors (vector of vectors) of the element types above.
%template(FloatVectorVector) std::vector<std::vector<float> >;
%template(ByteVectorVector) std::vector<std::vector<unsigned char> >;
%template(CharVectorVector) std::vector<std::vector<char> >;
%template(LongVectorVector) std::vector<std::vector<long> >;
// SWIG cannot build a std::initializer_list argument from Python, so the
// (std::initializer_list<int>, int) constructor of each Ascend config class
// is hidden. Any other constructors declared in those headers are still
// wrapped.
%ignore faiss::ascend::AscendIndexConfig::AscendIndexConfig(std::initializer_list<int>, int);
%ignore faiss::ascend::AscendIndexIVFConfig::AscendIndexIVFConfig(std::initializer_list<int>, int);
%ignore faiss::ascend::AscendIndexIVFSQConfig::AscendIndexIVFSQConfig(std::initializer_list<int>, int);
%ignore faiss::ascend::AscendIndexInt8Config::AscendIndexInt8Config(std::initializer_list<int>, int);
%ignore faiss::ascend::AscendIndexInt8FlatConfig::AscendIndexInt8FlatConfig(std::initializer_list<int>, int);
%ignore faiss::ascend::AscendIndexSQConfig::AscendIndexSQConfig(std::initializer_list<int>, int);
%ignore faiss::ascend::AscendIndexFlatConfig::AscendIndexFlatConfig(std::initializer_list<int>, int);
%ignore faiss::ascend::AscendIndexIVFSPConfig::AscendIndexIVFSPConfig(std::initializer_list<int>, int);
%ignore faiss::ascend::AscendIndexIVFSQTConfig::AscendIndexIVFSQTConfig(std::initializer_list<int>, int);
%ignore faiss::ascend::AscendIndexBinaryFlatConfig::AscendIndexBinaryFlatConfig(std::initializer_list<int>, int);
// Headers whose declarations are wrapped. The order appears chosen so that
// base-class headers (Index, AscendIndex, AscendIndexIVF, AscendIndexInt8)
// come before the headers deriving from them. SWIG must see a base class
// first to model the inheritance.
%include <faiss/MetricType.h>
%include <faiss/impl/platform_macros.h>
%include <faiss/Index.h>
%include <faiss/IndexBinary.h>
%include <faiss/Clustering.h>
%include <faiss/ascend/AscendIndex.h>
%include <faiss/ascend/AscendIndexIVF.h>
%include <faiss/ascend/AscendIndexIVFSQ.h>
%include <faiss/ascend/AscendIndexInt8.h>
%include <faiss/ascend/AscendIndexInt8Flat.h>
%include <faiss/ascend/AscendIndexFlat.h>
%include <faiss/ascend/AscendIndexSQ.h>
%include <faiss/ascend/AscendIndexIVFSP.h>
%include <faiss/ascend/custom/AscendIndexIVFSQT.h>
%include <faiss/ascend/AscendIndexBinaryFlat.h>
%include <faiss/clone_index.h>
%include <faiss/ascend/AscendClonerOptions.h>
%include <faiss/ascend/utils/Version.h>
// Members excluded from the wrapper before their headers are processed.
// NOTE(review): presumably internal helper types/functions (scanners,
// distance computers, quantizer internals) that are not meant for Python
// or are not wrappable -- confirm.
%ignore faiss::InvertedListScanner;
%ignore faiss::ScalarQuantizer::SQDistanceComputer;
%ignore faiss::ScalarQuantizer::Quantizer;
%ignore faiss::ScalarQuantizer::select_InvertedListScanner;
%ignore faiss::ScalarQuantizer::get_distance_computer;
%include <faiss/impl/ScalarQuantizer.h>
// Nested types and members of the auxiliary structures / ID selectors that
// are kept out of the wrapper (same NOTE(review) as above on the reason).
%ignore faiss::BufferList::Buffer;
%ignore faiss::RangeSearchPartialResult::QueryResult;
%ignore faiss::IDSelectorBatch::set;
%ignore faiss::IDSelectorBatch::bloom;
%ignore faiss::InterruptCallback::instance;
%ignore faiss::InterruptCallback::lock;
%include <faiss/impl/AuxIndexStructures.h>
%include <faiss/impl/IDSelector.h>
// Helpers for the %typemap(out) rules that follow. Each macro expands to
//     if (dynamic_cast<T *>(ptr)) { wrap ptr with T's SWIG type } else
// so consecutive invocations chain into an if / else-if cascade. The code
// after the last invocation supplies the final else branch. The ownership
// flag passed to SWIG_NewPointerObj is the typemap's owner setting
// (e.g. set by %newobject).
// NOTE(review): these are only defined for Python, while the typemaps using
// them are not guarded by SWIGPYTHON.
#ifdef SWIGPYTHON
// DOWNCAST(subclass): try faiss::subclass.
%define DOWNCAST(subclass)
if (dynamic_cast<faiss::subclass *> ($1)) {
$result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## subclass,$owner);
} else
%enddef
// DOWNCAST2(subclass, longname): try faiss::subclass, but wrap the pointer
// with the SWIG type descriptor of faiss::longname.
%define DOWNCAST2(subclass, longname)
if (dynamic_cast<faiss::subclass *> ($1)) {
$result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ ## longname,$owner);
} else
%enddef
// DOWNCAST_ASCEND(subclass): try faiss::ascend::subclass.
%define DOWNCAST_ASCEND(subclass)
if (dynamic_cast<faiss::ascend::subclass *> ($1)) {
$result = SWIG_NewPointerObj($1,SWIGTYPE_p_faiss__ascend__ ## subclass,$owner);
} else
%enddef
#endif
// Output typemap for faiss::Index *: wrap the returned pointer as the most
// derived class that the dynamic_cast cascade recognizes.
// Subclasses should appear before their parent. The first successful
// dynamic_cast wins, so a parent listed earlier would shadow its subclass.
// AscendIndexIVFSQT is therefore tested before AscendIndexIVFSQ (its
// presumed base -- NOTE(review): confirm in AscendIndexIVFSQT.h; the order
// is harmless if they are unrelated). Unlisted Ascend indexes fall back to
// AscendIndexIVF / AscendIndex before the generic faiss::Index.
%typemap(out) faiss::Index * {
    DOWNCAST_ASCEND ( AscendIndexFlat )
    DOWNCAST_ASCEND ( AscendIndexIVFSQT )
    DOWNCAST_ASCEND ( AscendIndexIVFSQ )
    DOWNCAST_ASCEND ( AscendIndexSQ )
    DOWNCAST_ASCEND ( AscendIndexIVFSP )
    DOWNCAST_ASCEND ( AscendIndexIVF )
    DOWNCAST_ASCEND ( AscendIndex )
    // default for non-recognized classes
    DOWNCAST ( Index )
    if ($1 == NULL)
    {
#ifdef SWIGPYTHON
        $result = SWIG_Py_Void();
#endif
    } else {
        assert(false);
    }
}
// Output typemap for faiss::IndexBinary *: wrap as AscendIndexBinaryFlat when
// possible, otherwise as the generic faiss::IndexBinary. Without that
// fallback, a non-null pointer of any other binary index type reached
// assert(false) and left the Python result unset in NDEBUG builds.
%typemap(out) faiss::IndexBinary * {
    DOWNCAST_ASCEND ( AscendIndexBinaryFlat )
    // default for non-recognized classes
    DOWNCAST ( IndexBinary )
    if ($1 == NULL)
    {
#ifdef SWIGPYTHON
        $result = SWIG_Py_Void();
#endif
    } else {
        assert(false);
    }
}
// Output typemap for faiss::ascend::AscendIndexInt8 *: wrap as
// AscendIndexInt8Flat when possible, otherwise as the AscendIndexInt8 base
// class itself. Without that fallback, any other non-null int8 index reached
// assert(false) and left the Python result unset in NDEBUG builds.
%typemap(out) faiss::ascend::AscendIndexInt8 * {
    DOWNCAST_ASCEND ( AscendIndexInt8Flat )
    // default for non-recognized classes
    DOWNCAST_ASCEND ( AscendIndexInt8 )
    if ($1 == NULL)
    {
#ifdef SWIGPYTHON
        $result = SWIG_Py_Void();
#endif
    } else {
        assert(false);
    }
}
// downcast_index(index) returns its argument unchanged. Because the return
// type is faiss::Index *, the result goes through the faiss::Index * output
// typemap above. Pointers obtained elsewhere (e.g. by reading an object field
// directly) are thereby re-wrapped as their most-derived recognized class.
%inline %{
faiss::Index *downcast_index(faiss::Index *index) { return index; }
%}
// The cloner functions return newly created indexes. %newobject tells SWIG
// that the Python object wrapping each returned pointer owns it and deletes
// it when the Python object is destroyed.
%newobject index_ascend_to_cpu;
%newobject index_cpu_to_ascend;
%newobject index_int8_ascend_to_cpu;
%newobject index_int8_cpu_to_ascend;
%include <faiss/ascend/AscendCloner.h>
/*******************************************************************
* Python specific: numpy array <-> C++ pointer interface
*******************************************************************/
#ifdef SWIGPYTHON
%{
// swig_ptr: expose the data buffer of a C-contiguous numpy array as a SWIG
// pointer object, so it can be passed to wrapped C++ functions that take raw
// pointers. The dtype selects the pointer type:
//   float32 -> float *, float64 -> double *, uint8 -> unsigned char *,
//   int8 -> char *, int32 -> int *, uint64 -> unsigned long *,
//   int64 -> long *.
// The pointer object does not own the memory (ownership flag 0), so the
// numpy array must outlive every use of the returned pointer.
// On failure, returns NULL with a Python ValueError set: the input is not a
// numpy array, is not C-contiguous, or has an unsupported dtype.
PyObject *swig_ptr (PyObject *a)
{
    if (!PyArray_Check(a)) {
        PyErr_SetString(PyExc_ValueError, "input not a numpy array");
        return NULL;
    }

    PyArrayObject *array = reinterpret_cast<PyArrayObject *>(a);
    if (!PyArray_ISCONTIGUOUS(array)) {
        PyErr_SetString(PyExc_ValueError, "array is not C-contiguous");
        return NULL;
    }

    void *buffer = PyArray_DATA(array);
    switch (PyArray_TYPE(array)) {
        case NPY_FLOAT32:
            return SWIG_NewPointerObj(buffer, SWIGTYPE_p_float, 0);
        case NPY_FLOAT64:
            return SWIG_NewPointerObj(buffer, SWIGTYPE_p_double, 0);
        case NPY_UINT8:
            return SWIG_NewPointerObj(buffer, SWIGTYPE_p_unsigned_char, 0);
        case NPY_INT8:
            return SWIG_NewPointerObj(buffer, SWIGTYPE_p_char, 0);
        case NPY_INT32:
            return SWIG_NewPointerObj(buffer, SWIGTYPE_p_int, 0);
        case NPY_UINT64:
            return SWIG_NewPointerObj(buffer, SWIGTYPE_p_unsigned_long, 0);
        case NPY_INT64:
            return SWIG_NewPointerObj(buffer, SWIGTYPE_p_long, 0);
        default:
            break;
    }

    PyErr_SetString(PyExc_ValueError, "did not recognize array type");
    return NULL;
}
%}
// Module initialisation code: numpy's C API must be imported before any
// PyArray_* function (used by swig_ptr / rev_swig_ptr) is called.
%init %{
/* needed, else crash at runtime */
import_array();
%}
// Declaration seen by SWIG, which generates the Python binding for swig_ptr:
// return a pointer usable as input for functions that expect pointers
PyObject *swig_ptr (PyObject *a);
// REV_SWIG_PTR(ctype, numpytype) defines an overload rev_swig_ptr(src, size)
// that views `size` elements starting at the C pointer `src` as a 1-D numpy
// array of dtype `numpytype`, without copying (PyArray_SimpleNewFromData).
// The array does not own or keep alive the memory, so the caller must keep
// the underlying buffer valid while the array is in use.
// The wrapper-side definition takes npy_intp (numpy's dimension type), while
// SWIG is shown a size_t parameter. The generated wrapper converts between
// the two implicitly.
// One overload is instantiated per dtype that swig_ptr accepts, including
// float64 / double.
%define REV_SWIG_PTR(ctype, numpytype)
%{
PyObject * rev_swig_ptr(ctype *src, npy_intp size) {
    return PyArray_SimpleNewFromData(1, &size, numpytype, src);
}
%}
PyObject * rev_swig_ptr(ctype *src, size_t size);
%enddef
REV_SWIG_PTR(float, NPY_FLOAT32);
REV_SWIG_PTR(double, NPY_FLOAT64);
REV_SWIG_PTR(unsigned char, NPY_UINT8);
REV_SWIG_PTR(char, NPY_INT8);
REV_SWIG_PTR(int, NPY_INT32);
REV_SWIG_PTR(int64_t, NPY_INT64);
REV_SWIG_PTR(uint64_t, NPY_UINT64);
#endif
// End of file...