#include "VisionML.h"
#include "KisOptionButtonStrip.h"
#include "KoColorSpace.h"
#include "KoJsonTrader.h"
#include "KoResourcePaths.h"
#include "kis_icon_utils.h"
#include "kis_paint_device.h"
#include <klocalizedstring.h>
#include <ksharedconfig.h>
#include <QCoreApplication>
#include <QDebug>
#include <QDesktopServices>
#include <QDir>
#include <QFileSystemWatcher>
#include <QHBoxLayout>
#include <QMessageBox>
#include <QMutexLocker>
#include <QSignalBlocker>
#include <QString>
#include <QToolButton>
#include <QUrl>
#include <string>
#include <ggml-backend.h>
namespace
{
void handleGGMLFatalError(const char *error_message)
{
qCritical() << "[GGML] Fatal error:" << error_message;
throw visp::exception(error_message);
}
struct Paths {
QString plugin;
QString lib;
QString models;
} paths;
void initPaths()
{
QString user = KoResourcePaths::getAppDataLocation();
paths.plugin = user + "/pykrita/vision_tools/";
paths.lib = paths.plugin + "lib/";
paths.models = paths.plugin + "models/";
if (!QDir(paths.plugin).exists()) {
throw std::runtime_error("Plugin directory not found (expected at " + paths.plugin.toStdString() + ")");
}
}
void loadGGMLBackend(char const *name)
{
#if defined(WIN32)
char const *ext = "dll";
#elif defined(__APPLE__)
char const *ext = "dylib";
#else
char const *ext = "so";
#endif
QString path = QString("%1ggml-%2.%3").arg(paths.lib, name, ext);
ggml_backend_load(path.toUtf8().constData());
}
QString findModelPath(VisionMLTask task)
{
switch (task) {
case VisionMLTask::segmentation:
return paths.models + "sam";
case VisionMLTask::background_removal:
return paths.models + "birefnet";
case VisionMLTask::inpainting:
return paths.models + "migan";
default:
return paths.models;
}
}
char const *toString(VisionMLTask task)
{
switch (task) {
case VisionMLTask::segmentation:
return "segmentation";
case VisionMLTask::inpainting:
return "inpainting";
case VisionMLTask::background_removal:
return "background_removal";
default:
return "unknown";
}
}
void unloadFromGPU(visp::compute_graph &graph, visp::backend_type devType)
{
if (devType == visp::backend_type::gpu) {
graph = {};
}
}
}
/**
 * Factory for the shared VisionModels instance.
 *
 * Resolves the plugin directories first (throws std::runtime_error if the plugin folder is
 * missing), then constructs the models object. Returns nullptr when construction did not end
 * up with a usable backend.
 */
QSharedPointer<VisionModels> VisionModels::create()
{
    initPaths();
    auto models = QSharedPointer<VisionModels>(new VisionModels());
    // The constructor only assigns m_backend when a backend initialized successfully.
    if (!models->m_backend) {
        return nullptr;
    }
    return models;
}
/**
 * Loads the ggml backend libraries, restores the persisted backend and model choices from the
 * "VisionML" config group and initializes the inference backend.
 *
 * If the configured backend fails to initialize, the other backend is tried. If both fail, a
 * warning containing both error messages is shown, m_backend is not assigned and the
 * aboutToQuit cleanup is not connected.
 */
VisionModels::VisionModels()
{
    ggml_set_abort_callback(handleGGMLFatalError);
    loadGGMLBackend("cpu");
    loadGGMLBackend("vulkan");

    m_config = KSharedConfig::openConfig()->group("VisionML");
    QString backendString = m_config.readEntry("backend", "cpu");
    visp::backend_type backendType = backendString == "gpu" ? visp::backend_type::gpu : visp::backend_type::cpu;

    configureModel(VisionMLTask::segmentation, "sam/MobileSAM-F16.gguf");
    configureModel(VisionMLTask::inpainting, "migan/MIGAN-512-places2-F16.gguf");
    configureModel(VisionMLTask::background_removal, "birefnet/BiRefNet-lite-F16.gguf");

    QString err = initialize(backendType);
    if (!err.isEmpty()) {
        // Fall back to the other backend, but keep the first error: if the fallback fails as
        // well, the warning should also explain why the preferred backend was rejected.
        QString const preferredErr = err;
        backendType = backendType == visp::backend_type::gpu ? visp::backend_type::cpu : visp::backend_type::gpu;
        err = initialize(backendType);
        if (!err.isEmpty()) {
            err = preferredErr + "\n" + err;
        }
    }
    if (!err.isEmpty()) {
        QMessageBox::warning(nullptr,
                             i18nc("@title:window", "Krita - VisionML Plugin"),
                             i18n("Failed to initialize AI tools plugin.\n") + err);
        return;
    }
    connect(QCoreApplication::instance(), SIGNAL(aboutToQuit()), this, SLOT(cleanUp()));
}
/**
 * Picks the model file for @p task: the name stored under "model_<task>" in the config, or
 * @p defaultName when nothing is stored or the stored file no longer exists under
 * paths.models.
 */
void VisionModels::configureModel(VisionMLTask task, QString const &defaultName)
{
    QString const key = QString("model_%1").arg(toString(task));
    QString const configured = m_config.readEntry(key, defaultName);
    bool const configuredExists = QFile::exists(paths.models + configured);
    m_modelName[static_cast<int>(task)] = configuredExists ? configured : defaultName;
}
/**
 * (Re)creates the inference backend of the given type.
 *
 * All loaded models are released first, under m_mutex. On success the backend type is stored
 * in m_backendType and persisted to the config.
 *
 * @return an empty string on success, otherwise the message of the exception thrown by
 *         visp::backend_init() (m_backend and m_backendType are then left unchanged).
 */
QString VisionModels::initialize(visp::backend_type backendType)
{
    QMutexLocker guard(&m_mutex);
    unloadModels();
    try {
        m_backend = visp::backend_init(backendType);
    } catch (std::exception const &ex) {
        return QString(ex.what());
    }
    m_backendType = backendType;
    bool const isGpu = backendType == visp::backend_type::gpu;
    m_config.writeEntry("backend", isGpu ? "gpu" : "cpu");
    return {};
}
/**
 * Runs the SAM image encoder on @p image, loading the segmentation model on first use.
 *
 * Loading SAM first releases every other loaded model (unloadModels()). Throws
 * std::runtime_error (from modelPath()) if the configured model file is missing.
 */
void VisionModels::encodeSegmentationImage(visp::image_view const &image)
{
    QMutexLocker guard(&m_mutex);
    bool const samMissing = !m_sam.weights;
    if (samMissing) {
        unloadModels();
        QByteArray modelFile = modelPath(VisionMLTask::segmentation);
        m_sam = visp::sam_load_model(modelFile.data(), m_backend);
    }
    visp::sam_encode(m_sam, image);
}
/**
 * Returns true when m_sam.input_image is non-null, presumably meaning
 * encodeSegmentationImage() has already encoded an image (set inside visp::sam_encode — confirm).
 *
 * NOTE(review): unlike the other methods touching m_sam, this reads without taking m_mutex;
 * confirm it is never called concurrently with encodeSegmentationImage().
 */
bool VisionModels::hasSegmentationImage() const
{
return m_sam.input_image != nullptr;
}
/**
 * Predicts a segmentation mask for a single point prompt using the SAM model.
 *
 * NOTE(review): no check that a model is loaded / an image was encoded; presumably callers
 * invoke encodeSegmentationImage() first.
 */
visp::image_data VisionModels::predictSegmentationMask(visp::i32x2 point)
{
    QMutexLocker guard(&m_mutex);
    visp::image_data mask = visp::sam_compute(m_sam, point);
    return mask;
}
/**
 * Predicts a segmentation mask for a bounding-box prompt using the SAM model.
 *
 * NOTE(review): no check that a model is loaded / an image was encoded; presumably callers
 * invoke encodeSegmentationImage() first.
 */
visp::image_data VisionModels::predictSegmentationMask(visp::box_2d box)
{
    QMutexLocker guard(&m_mutex);
    visp::image_data mask = visp::sam_compute(m_sam, box);
    return mask;
}
/**
 * Computes a background-removal mask for @p originalImage with BiRefNet, loading the model
 * on first use.
 *
 * For models with a dynamic input size (params.image_size == -1) on the CPU backend, inputs
 * whose longer side exceeds maxCpuSide are downscaled first. Each side is rounded up to
 * params.image_multiple, and the resulting mask is scaled back to the original extent.
 * On the GPU backend the compute graph is dropped after the run (unloadFromGPU).
 */
visp::image_data VisionModels::removeBackground(visp::image_view const &originalImage)
{
    // Longest side processed at full resolution with a dynamic-size model on the CPU backend.
    constexpr int maxCpuSide = 2304;

    QMutexLocker guard(&m_mutex);
    if (!m_birefnet.weights) {
        QByteArray modelFile = modelPath(VisionMLTask::background_removal);
        m_birefnet = visp::birefnet_load_model(modelFile.data(), m_backend);
    }

    visp::image_data downscaled;
    visp::image_view input = originalImage;
    bool const dynamicSize = m_birefnet.params.image_size == -1;
    if (dynamicSize) {
        int const longestSide = std::max(input.extent[0], input.extent[1]);
        if (longestSide > maxCpuSide && m_backendType == visp::backend_type::cpu) {
            float const scale = float(maxCpuSide) / float(longestSide);
            int const multiple = m_birefnet.params.image_multiple;
            auto roundUp = [multiple](int value) {
                return (value + multiple - 1) / multiple * multiple;
            };
            auto target = visp::i32x2{roundUp(int(input.extent[0] * scale)), roundUp(int(input.extent[1] * scale))};
            downscaled = visp::image_scale(input, target);
            input = downscaled;
        }
    }

    auto mask = visp::birefnet_compute(m_birefnet, input);
    unloadFromGPU(m_birefnet.graph, m_backendType);
    if (downscaled.data) {
        mask = visp::image_scale(mask, originalImage.extent);
    }
    return mask;
}
/**
 * Inpaints the regions of @p image selected by @p mask with MI-GAN, loading the model on
 * first use. Throws std::runtime_error (from modelPath()) if the model file is missing.
 */
visp::image_data VisionModels::inpaint(visp::image_view const &image, visp::image_view const &mask)
{
    QMutexLocker guard(&m_mutex);
    bool const miganMissing = !m_migan.weights;
    if (miganMissing) {
        QByteArray modelFile = modelPath(VisionMLTask::inpainting);
        m_migan = visp::migan_load_model(modelFile.data(), m_backend);
    }
    return visp::migan_compute(m_migan, image, mask);
}
/**
 * Releases the model used for @p task. Only acts on the GPU backend; with the CPU backend
 * this is a no-op.
 */
void VisionModels::unload(VisionMLTask task)
{
    if (m_backendType != visp::backend_type::gpu) {
        return;
    }
    QMutexLocker guard(&m_mutex);
    switch (task) {
    case VisionMLTask::segmentation:
        m_sam = {};
        break;
    case VisionMLTask::inpainting:
        m_migan = {};
        break;
    case VisionMLTask::background_removal:
        m_birefnet = {};
        break;
    default:
        break;
    }
}
/**
 * Absolute path (UTF-8) of the configured model file for @p task.
 *
 * @throws std::runtime_error if the file does not exist.
 */
QByteArray VisionModels::modelPath(VisionMLTask task) const
{
    QString const fullPath = paths.models + modelName(task);
    if (QFile::exists(fullPath)) {
        return fullPath.toUtf8();
    }
    throw std::runtime_error("Model file not found: " + fullPath.toStdString());
}
/// Backend type recorded by the last successful initialize() call.
visp::backend_type VisionModels::backend() const
{
return m_backendType;
}
/**
 * Switches the inference backend.
 *
 * Does nothing (and returns true) when @p backendType is already active. On failure a warning
 * box is shown and false is returned; on success backendChanged() is emitted.
 */
bool VisionModels::setBackend(visp::backend_type backendType)
{
    if (backendType != m_backendType) {
        QString const error = initialize(backendType);
        if (!error.isEmpty()) {
            QMessageBox::warning(nullptr,
                                 i18nc("@title:window", "Krita - Vision ML Tools Plugin"),
                                 i18n("Error while trying to switch inference backend.\n") + error);
            return false;
        }
        Q_EMIT backendChanged(m_backendType);
    }
    return true;
}
/// Configured model file for @p task, relative to the models folder (e.g. "sam/<file>.gguf").
QString const &VisionModels::modelName(VisionMLTask task) const
{
    auto const slot = static_cast<int>(task);
    return m_modelName[slot];
}
/**
 * Selects the model file @p name (relative to the models folder) for @p task, persists the
 * choice and unloads all models so the new one is loaded on next use.
 *
 * Emits modelNameChanged() when the name actually changes.
 */
void VisionModels::setModelName(VisionMLTask task, QString const &name)
{
    if (modelName(task) == name) {
        return;
    }
    {
        QMutexLocker lock(&m_mutex);
        m_modelName[(int)task] = name;
        // Use the same key configureModel() reads ("model_<task name>"); writing the numeric
        // task value meant the selection was never picked up again on the next start.
        m_config.writeEntry(QString("model_%1").arg(toString(task)), name);
        unloadModels();
    }
    // Emit only after releasing m_mutex: directly connected slots (e.g. a model combo box
    // re-selecting its entry) may call back into setModelName() or other locking methods,
    // which would self-deadlock on the non-recursive mutex.
    Q_EMIT modelNameChanged(task, name);
}
/**
 * Human-readable description of the active backend device, formatted as
 * "<description> [<device name>]". Returns an empty string when no backend is set
 * (e.g. after cleanUp()).
 */
QString VisionModels::backendDeviceDescription() const
{
    // cleanUp() resets m_backend, and ggml_backend_get_device() dereferences its argument.
    if (!m_backend) {
        return QString();
    }
    ggml_backend_dev_t dev = ggml_backend_get_device(m_backend);
    char const *name = ggml_backend_dev_name(dev);
    char const *desc = ggml_backend_dev_description(dev);
    return QString("%1 [%2]").arg(QString(desc).trimmed(), name);
}
/**
 * Releases the SAM, BiRefNet and MI-GAN models by resetting them to default state.
 *
 * Does not lock m_mutex itself: encodeSegmentationImage() and setModelName() call it with the
 * mutex held, cleanUp() calls it without.
 */
void VisionModels::unloadModels()
{
m_sam = {};
m_birefnet = {};
m_migan = {};
}
/**
 * Slot connected to QCoreApplication::aboutToQuit in the constructor.
 * Releases all models and then the backend.
 */
void VisionModels::cleanUp()
{
// Models are loaded with m_backend (see the *_load_model calls), so drop them first.
unloadModels();
m_backend = {};
}
/**
 * Copies the pixels of @p device inside @p bounds into a QImage and exposes them as a
 * bgra_u8 visp::image_view.
 *
 * An empty @p bounds means the device's exact bounds. 4-byte "RGBA" devices are copied
 * directly with readBytes(); other colour spaces go through KisPaintDevice::convertToQImage().
 * If the resolved area is empty, a default-constructed result is returned.
 */
VisionMLImage VisionMLImage::prepare(KisPaintDevice const &device, QRect bounds)
{
    VisionMLImage prepared;
    QRect const area = bounds.isEmpty() ? device.exactBounds() : bounds;
    if (area.isEmpty()) {
        return prepared;
    }

    KoColorSpace const *colorSpace = device.colorSpace();
    bool const isRgba8 = colorSpace->pixelSize() == 4 && colorSpace->id() == "RGBA";
    if (isRgba8) {
        prepared.data = QImage(area.width(), area.height(), QImage::Format_ARGB32);
        device.readBytes(prepared.data.bits(), area);
    } else {
        prepared.data = device.convertToQImage(nullptr, area);
    }

    auto &view = prepared.view;
    view.extent = {prepared.data.width(), prepared.data.height()};
    view.stride = prepared.data.bytesPerLine();
    view.format = visp::image_format::bgra_u8;
    view.data = prepared.data.bits();
    return prepared;
}
/**
 * Copies the region @p b of an RGBA8 image view into a new QImage (Format_RGBA8888).
 *
 * @param img source view; must use visp::image_format::rgba_u8.
 * @param b   region of @p img to copy, in image pixel coordinates. An empty rect selects the
 *            whole image. The result always has the size of @p b; parts of @p b that lie
 *            outside the image are left fully transparent instead of being read out of bounds.
 * @return the copied region, or a null QImage if it could not be allocated.
 * @throws std::runtime_error if @p img is not rgba_u8.
 */
QImage VisionMLImage::convertToQImage(visp::image_view const &img, QRect b)
{
    if (img.format != visp::image_format::rgba_u8) {
        throw std::runtime_error("Unsupported image format for conversion to QImage");
    }
    QRect const imageRect(0, 0, img.extent[0], img.extent[1]);
    if (b.isEmpty()) {
        b = imageRect;
    }
    QImage result(b.width(), b.height(), QImage::Format_RGBA8888);
    if (result.isNull()) {
        return result;
    }

    // Only the part of b that overlaps the image is copied; the rest stays transparent.
    QRect const src = b.intersected(imageRect);
    if (src != b) {
        result.fill(Qt::transparent);
    }
    if (src.isEmpty()) {
        return result;
    }

    size_t const pixelBytes = n_bytes(img.format);
    // Honour the view's row pitch (prepare() fills it from QImage::bytesPerLine()); fall back
    // to tightly packed rows if the view does not specify one.
    size_t const srcStride = img.stride > 0 ? size_t(img.stride) : size_t(img.extent[0]) * pixelBytes;
    size_t const rowBytes = size_t(src.width()) * pixelBytes;
    size_t const dstOffset = size_t(src.x() - b.x()) * pixelBytes;
    uint8_t const *base = reinterpret_cast<uint8_t const *>(img.data) + size_t(src.x()) * pixelBytes;
    for (int row = 0; row < src.height(); ++row) {
        int const y = src.y() + row;
        memcpy(result.scanLine(y - b.y()) + dstOffset, base + size_t(y) * srcStride, rowBytes);
    }
    return result;
}
/**
 * Backend selector: a CPU/GPU toggle strip, optionally followed by a label describing the
 * active device.
 *
 * Buttons for backends that visp reports as unavailable are disabled with an explanatory
 * tooltip. The widget follows VisionModels::backendChanged and requests switches through
 * switchBackend().
 */
VisionMLBackendWidget::VisionMLBackendWidget(QSharedPointer<VisionModels> shared, bool showDevice, QWidget *parent)
    : KisOptionCollectionWidgetWithHeader(i18n("Backend"), parent)
    , m_shared(std::move(shared))
{
    auto *container = new QWidget;
    auto *row = new QHBoxLayout(container);
    row->setContentsMargins(0, 0, 0, 0);

    auto *buttons = new KisOptionButtonStrip;
    m_cpuButton = buttons->addButton(i18n("CPU"));
    m_gpuButton = buttons->addButton(i18n("GPU"));

    bool const cpuAvailable = visp::backend_is_available(visp::backend_type::cpu);
    if (!cpuAvailable) {
        m_cpuButton->setEnabled(false);
        m_cpuButton->setToolTip(i18n("CPU backend not available, hardware is not supported"));
    }
    bool const gpuAvailable = visp::backend_is_available(visp::backend_type::gpu);
    if (!gpuAvailable) {
        m_gpuButton->setEnabled(false);
        m_gpuButton->setToolTip(i18n("GPU backend not available, no supported devices found"));
    }
    row->addWidget(buttons);

    if (showDevice) {
        m_deviceLabel = new QLabel;
        row->addWidget(m_deviceLabel);
    }
    setPrimaryWidget(container);

    connect(buttons, SIGNAL(buttonToggled(KoGroupButton *, bool)), this, SLOT(switchBackend(KoGroupButton *, bool)));
    connect(m_shared.get(), SIGNAL(backendChanged(visp::backend_type)), this, SLOT(updateBackend(visp::backend_type)));
    // Reflect the backend that is active right now.
    updateBackend(m_shared->backend());
}
/// Checks the button matching @p backend and refreshes the device label, if shown.
void VisionMLBackendWidget::updateBackend(visp::backend_type backend)
{
    bool const onCpu = backend == visp::backend_type::cpu;
    bool const onGpu = backend == visp::backend_type::gpu;
    m_cpuButton->setChecked(onCpu);
    m_gpuButton->setChecked(onGpu);
    if (!m_deviceLabel) {
        return;
    }
    m_deviceLabel->setText(m_shared->backendDeviceDescription().trimmed());
}
/**
 * Reacts to a toggle in the backend strip by asking VisionModels to switch backends.
 *
 * If the switch fails, the requested button is disabled and the button of the still-active
 * backend is re-checked with its signals blocked.
 */
void VisionMLBackendWidget::switchBackend(KoGroupButton *button, bool checked)
{
    if (!checked) {
        return;
    }
    visp::backend_type const requested = button == m_cpuButton ? visp::backend_type::cpu : visp::backend_type::gpu;
    if (m_shared->setBackend(requested)) {
        return;
    }
    button->setEnabled(false);
    KoGroupButton *active = m_shared->backend() == visp::backend_type::cpu ? m_cpuButton : m_gpuButton;
    bool const wasBlocked = active->blockSignals(true);
    active->setChecked(true);
    active->blockSignals(wasBlocked);
}
/**
 * Combo box listing the *.gguf model files available for @p task, optionally with a button
 * that opens the model folder.
 *
 * The list is refreshed whenever the task's model folder changes on disk, and follows
 * VisionModels::modelNameChanged.
 */
VisionMLModelSelect::VisionMLModelSelect(QSharedPointer<VisionModels> models,
                                         VisionMLTask task,
                                         bool showFolderButton,
                                         QWidget *parent)
    : KisOptionCollectionWidgetWithHeader(i18n("Model"), parent)
    , m_shared(std::move(models))
    , m_task(task)
{
    auto *container = new QWidget;
    auto *row = new QHBoxLayout(container);
    row->setContentsMargins(0, 0, 0, 0);

    // Populate and pre-select before currentIndexChanged is connected, so the initial
    // selection is not written back to VisionModels.
    m_select = new QComboBox;
    updateModels();
    updateModel(m_task, m_shared->modelName(m_task));
    connect(m_select, SIGNAL(currentIndexChanged(int)), this, SLOT(switchModel(int)));
    connect(m_shared.get(), &VisionModels::modelNameChanged, this, &VisionMLModelSelect::updateModel);
    row->addWidget(m_select);

    if (showFolderButton) {
        auto *openFolder = new QToolButton;
        openFolder->setIcon(KisIconUtils::loadIcon("document-open"));
        openFolder->setFixedSize(24, 24);
        openFolder->setToolTip(i18n("Open models folder"));
        connect(openFolder, &QToolButton::clicked, this, &VisionMLModelSelect::openModelsFolder);
        row->addWidget(openFolder);
    }

    m_fileWatcher = new QFileSystemWatcher(this);
    m_fileWatcher->addPath(findModelPath(m_task));
    connect(m_fileWatcher, &QFileSystemWatcher::directoryChanged, this, &VisionMLModelSelect::updateModels);
    setPrimaryWidget(container);
}
void VisionMLModelSelect::updateModels()
{
m_select->blockSignals(true);
QVariant current = m_select->currentData();
m_select->clear();
auto addModels = [this](const QString &arch) {
QDir modelDir(paths.models + arch);
QStringList modelFiles = modelDir.entryList(QStringList() << "*.gguf", QDir::Files);
for (QString &file : modelFiles) {
QString fullName = arch + "/" + file;
m_select->addItem(file.replace(".gguf", ""), fullName);
}
};
switch (m_task) {
case VisionMLTask::segmentation:
addModels("sam");
break;
case VisionMLTask::background_removal:
addModels("birefnet");
break;
case VisionMLTask::inpainting:
addModels("migan");
break;
default:
qWarning() << "Unknown VisionMLTask" << (int)m_task;
return;
}
m_select->blockSignals(false);
if (current.isValid()) {
updateModel(m_task, current.toString());
}
}
/// Applies the model stored in combo-box entry @p index; out-of-range indices are ignored.
void VisionMLModelSelect::switchModel(int index)
{
    bool const inRange = index >= 0 && index < m_select->count();
    if (!inRange) {
        return;
    }
    m_shared->setModelName(m_task, m_select->itemData(index).toString());
}
/**
 * Selects the entry whose item data equals @p name, or the first entry if there is none.
 * Notifications for other tasks are ignored.
 */
void VisionMLModelSelect::updateModel(VisionMLTask task, QString const &name)
{
    if (task == m_task) {
        int const found = m_select->findData(name);
        m_select->setCurrentIndex(found == -1 ? 0 : found);
    }
}
/// Opens the task's model folder in the system file browser; shows a warning on failure.
void VisionMLModelSelect::openModelsFolder()
{
    QString const folder = findModelPath(m_task);
    bool const opened = QDesktopServices::openUrl(QUrl::fromLocalFile(folder));
    if (opened) {
        return;
    }
    QMessageBox::warning(nullptr,
                         i18nc("@title:window", "Krita - Vision ML Tools Plugin"),
                         i18n("Failed to open folder: ") + folder);
}
/**
 * Routes errorOccurred to showError() through a queued connection, so the message box is
 * shown from this object's event loop even when the signal is emitted on another thread
 * (presumably an inference worker — confirm).
 */
VisionMLErrorReporter::VisionMLErrorReporter(QObject *parent)
    : QObject(parent)
{
    QObject::connect(this,
                     &VisionMLErrorReporter::errorOccurred,
                     this,
                     &VisionMLErrorReporter::showError,
                     Qt::QueuedConnection);
}
/// Shows @p message in a warning box prefixed with a generic processing-error text.
void VisionMLErrorReporter::showError(QString const &message) const
{
    QString const text = i18n("Error during image processing: ") + message;
    QMessageBox::warning(nullptr, i18nc("@title:window", "Krita - Vision ML Tools Plugin"), text);
}