13a81c2f创建于 2025年12月15日历史提交
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "A100"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aCl-IzLoDr2H"
      },
      "outputs": [],
      "source": [
        "# Use the %pip magic (not !pip) so the packages are installed into the\n",
        "# environment of the running kernel rather than whatever `pip` the shell finds.\n",
        "%pip install -U transformers mamba-ssm"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load Tokenizer and Model"
      ],
      "metadata": {
        "id": "SpRo_KJIRsxv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "# Single source of truth for the checkpoint so the tokenizer and the model\n",
        "# can never be loaded from two different repositories by accident.\n",
        "MODEL_ID = \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\"\n",
        "\n",
        "# Load tokenizer and model\n",
        "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    MODEL_ID,\n",
        "    dtype=torch.bfloat16,  # `torch_dtype` is the deprecated name of this argument\n",
        "    trust_remote_code=True,  # runs the modeling code shipped in the model repo\n",
        "    device_map=\"auto\",\n",
        ")\n"
      ],
      "metadata": {
        "id": "waveliieEI1n"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Define Input with Tools"
      ],
      "metadata": {
        "id": "xjVkqaSdRx0_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers.utils import get_json_schema\n",
        "\n",
        "# NOTE: the signature and docstring below are read by the chat template to\n",
        "# build the tool description, so they are part of the prompt, not just docs.\n",
        "def multiply(a: float, b: float):\n",
        "    \"\"\"\n",
        "    A function that multiplies two numbers\n",
        "\n",
        "    Args:\n",
        "        a: The first number to multiply\n",
        "        b: The second number to multiply\n",
        "    \"\"\"\n",
        "    return a * b\n",
        "\n",
        "messages = [\n",
        "    {\"role\": \"user\", \"content\": \"what is 2.0909090923 x 0.897987987\"},\n",
        "]\n",
        "\n",
        "tokenized_chat = tokenizer.apply_chat_template(\n",
        "    messages,\n",
        "    tools=[multiply],\n",
        "    tokenize=True,\n",
        "    add_generation_prompt=True,\n",
        "    # Ask explicitly for a bare input_ids tensor instead of relying on the\n",
        "    # library default: the inference cell passes `tokenized_chat` positionally\n",
        "    # to `model.generate`, which a dict-style return value would break.\n",
        "    return_dict=False,\n",
        "    return_tensors=\"pt\",\n",
        ").to(model.device)\n"
      ],
      "metadata": {
        "id": "zxZZ7iMZETsw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Inference"
      ],
      "metadata": {
        "id": "SVBAG3dLRw4v"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "outputs = model.generate(\n",
        "    tokenized_chat,\n",
        "    max_new_tokens=1024,\n",
        "    # temperature / top_p only take effect when sampling is switched on; enable\n",
        "    # it explicitly rather than relying on the checkpoint's generation config.\n",
        "    do_sample=True,\n",
        "    temperature=1.0,\n",
        "    top_p=1.0,\n",
        "    eos_token_id=tokenizer.eos_token_id,\n",
        ")\n",
        "# outputs[0] is the first (only) sequence in the batch: prompt tokens followed\n",
        "# by the newly generated tokens.\n",
        "print(tokenizer.decode(outputs[0]))"
      ],
      "metadata": {
        "id": "BKYqPT5ORDx3"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}