{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "A100"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aCl-IzLoDr2H"
},
"outputs": [],
"source": [
"# Use the %pip magic (not !pip) so packages are installed into the environment of the running kernel.\n",
"%pip install -U transformers mamba-ssm"
]
},
{
"cell_type": "markdown",
"source": [
"# Load Models"
],
"metadata": {
"id": "SpRo_KJIRsxv"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"# Single source of truth for the checkpoint so the tokenizer and the model cannot drift apart.\n",
"MODEL_ID = \"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16\"\n",
"\n",
"# Load the tokenizer (its chat template is used later to render the tool-calling prompt).\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
"\n",
"# Load the model weights in bfloat16.\n",
"# - trust_remote_code=True allows executing custom modeling code shipped in the model repository;\n",
"#   only enable it for repositories you trust.\n",
"# - device_map=\"auto\" delegates weight placement across the available devices to the library.\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    MODEL_ID,\n",
"    torch_dtype=torch.bfloat16,\n",
"    trust_remote_code=True,\n",
"    device_map=\"auto\",\n",
")\n"
],
"metadata": {
"id": "waveliieEI1n"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Define Input with Tools"
],
"metadata": {
"id": "xjVkqaSdRx0_"
}
},
{
"cell_type": "code",
"source": [
"from transformers.utils import get_json_schema\n",
"\n",
"# Tool exposed to the model: it is passed via `tools=[...]` to `apply_chat_template` below.\n",
"# Per the transformers chat-templating docs, a function passed as a tool is converted to a JSON schema\n",
"# from its type hints and Google-style docstring, so the docstring text below is effectively part of\n",
"# the prompt -- edit it deliberately, not cosmetically.\n",
"def multiply(a: float, b: float):\n",
" \"\"\"\n",
" A function that multiplies two numbers\n",
"\n",
" Args:\n",
" a: The first number to multiply\n",
" b: The second number to multiply\n",
" \"\"\"\n",
" return a * b\n",
"\n",
"# Single-turn conversation whose answer requires calling the `multiply` tool.\n",
"messages = [\n",
"    {\"role\": \"user\", \"content\": \"what is 2.0909090923 x 0.897987987\"},\n",
"]\n",
"\n",
"# Render the conversation and the tool definition with the model's chat template and tokenize it.\n",
"# return_dict=False is passed explicitly: the default return type of apply_chat_template differs\n",
"# between transformers releases, and the generate() call in the next cell expects a plain tensor of\n",
"# input ids as its positional argument.\n",
"tokenized_chat = tokenizer.apply_chat_template(\n",
"    messages,\n",
"    tools=[multiply],\n",
"    tokenize=True,\n",
"    add_generation_prompt=True,  # append the assistant-turn header so the model starts answering\n",
"    return_dict=False,\n",
"    return_tensors=\"pt\",\n",
").to(model.device)\n"
],
"metadata": {
"id": "zxZZ7iMZETsw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Inference"
],
"metadata": {
"id": "SVBAG3dLRw4v"
}
},
{
"cell_type": "code",
"source": [
"# Generation settings collected in one place so they are easy to read and tweak.\n",
"generation_kwargs = dict(\n",
"    max_new_tokens=1024,\n",
"    temperature=1.0,\n",
"    top_p=1.0,\n",
"    eos_token_id=tokenizer.eos_token_id,\n",
")\n",
"outputs = model.generate(tokenized_chat, **generation_kwargs)\n",
"\n",
"# outputs[0] is the first returned sequence; it is decoded as-is (special tokens are not skipped).\n",
"decoded_text = tokenizer.decode(outputs[0])\n",
"print(decoded_text)"
],
"metadata": {
"id": "BKYqPT5ORDx3"
},
"execution_count": null,
"outputs": []
}
]
}