# Your Name - commit message
# Commit d7290920, created on 2024-10-19 (commit history)
import argparse
import logging

import numpy as np
import torch
import os
from transformers import AutoConfig, FlaxAutoModelForCausalLM

# Root-logger settings, applied once at import time: INFO threshold, with
# each record rendered as "timestamp - level - logger name -   message".
_LOGGING_OPTIONS = {
    "level": logging.INFO,
    "datefmt": "%m/%d/%Y %H:%M:%S",
    "format": "%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
}
logging.basicConfig(**_LOGGING_OPTIONS)

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Input: local directory holding the checkpoint to convert.
# Output: directory that receives the converted model.
model_path = "./distilgpt2-base-pretrained-he"
save_directory = "./tmp/flax/"

# Build the model configuration from the config.json inside the checkpoint directory.
config_path = os.path.join(model_path, "config.json")
config = AutoConfig.from_pretrained(config_path)

# from_pt=True: load the weights from a PyTorch checkpoint (slower), then
# write the resulting Flax model out to save_directory.
model = FlaxAutoModelForCausalLM.from_pretrained(
    model_path,
    config=config,
    from_pt=True,
)
model.save_pretrained(save_directory)