* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "runtime/libentry/libentry.h"
using namespace libentry;
void libentry::ArgProcessor::classifyArguments(
const py::list& args,
const py::dict& kwargs,
const py::list& jit_params,
const std::unordered_set<int>& specialize_indices,
const std::unordered_set<int>& do_not_specialize_indices)
{
for (size_t i = 0; i < args.size(); ++i) {
if (specialize_indices.count(i)) {
k_args_.append(args[i]);
spec_args_.append(args[i]);
} else if (do_not_specialize_indices.count(i)) {
k_args_.append(args[i]);
dns_args_.append(args[i]);
} else {
const_args_.append(args[i]);
}
}
for (size_t i = args.size(); i < jit_params.size(); ++i) {
const py::object& param = jit_params[i];
py::object val;
if (kwargs.contains(param.attr("name"))) {
val = kwargs[param.attr("name")];
} else if (py::hasattr(param, "default") && !param.attr("default").is_none()) {
val = param.attr("default");
} else {
continue;
}
if (param.attr("is_constexpr").cast<py::bool_>()) {
const_args_.append(val);
} else if (param.attr("do_not_specialize").cast<py::bool_>()) {
dns_args_.append(val);
k_args_.append(val);
} else {
spec_args_.append(val);
k_args_.append(val);
}
}
}
KeyType libentry::ArgProcessor::generateKey()
{
auto is_tensor = [](py::handle x) {
return py::hasattr(x, "data_ptr");
};
auto is_int = [](py::handle x) {
return py::isinstance<py::int_>(x);
};
py::list spec_key;
for (auto arg : spec_args_) {
if (is_tensor(arg)) {
auto dtype = arg.attr("dtype");
uintptr_t data_ptr = arg.attr("data_ptr")().cast<uintptr_t>();
bool aligned = (data_ptr & (divisibility_ - 1)) == 0;
spec_key.append(py::make_tuple(dtype, aligned));
} else {
spec_key.append(py::make_tuple(py::type::of(arg), arg));
}
}
py::list dns_key;
for (auto arg : dns_args_) {
if (is_tensor(arg)) {
dns_key.append(arg.attr("dtype"));
} else if (!is_int(arg)) {
dns_key.append(py::type::of(arg));
} else {
int64_t val = arg.cast<int64_t>();
if (val >= -0x80000000LL && val <= 0x7FFFFFFFLL) {
dns_key.append(py::str("i32"));
} else if (val >= 0 && val <= 0xFFFFFFFFFFFFFFFFLL) {
dns_key.append(py::str("u64"));
} else {
dns_key.append(py::str("i64"));
}
}
}
py::list result;
auto list_append = [&](const py::list& src) {
for (auto handle : src) {
result.append(handle);
}
};
list_append(spec_key);
list_append(dns_key);
list_append(const_args_);
return result;
}
py::list libentry::ArgProcessor::getKArgs()
{
return k_args_;
}