| |
| import os |
| import yaml |
| import glob |
| import requests |
| from model import make_model_and_optimizer |
| import torch |
| from asteroid import torch_utils |
| from collections import OrderedDict |
|
|
# Experiment directory whose checkpoints/ sub-folder holds the Lightning
# checkpoint that this script converts.
exp_dir = "exp/tmp"

os.makedirs(os.path.join(exp_dir, "checkpoints"), exist_ok=True)

# Download the pretrained checkpoint when best-model.ckpt is absent. The check
# targets that exact file (rather than "any *.ckpt") because best-model.ckpt is
# the file the conversion step loads; a differently named checkpoint alone
# would otherwise skip the download and then fail at load time.
_pretrained_ckpt = os.path.join(exp_dir, "checkpoints", "best-model.ckpt")
if not os.path.exists(_pretrained_ckpt):
    print(f"downloading pretrained checkpoint to {_pretrained_ckpt}")
    with requests.get(
        "https://huggingface.co/JunzheJosephZhu/MultiDecoderDPRNN/resolve/main/best-model.ckpt",
        stream=True,
        timeout=60,
    ) as response:
        # Fail loudly on HTTP errors; otherwise an error page would be saved
        # as the checkpoint and every later run would skip the download.
        response.raise_for_status()
        # Stream into a temporary file and rename only once the transfer is
        # complete, so an interrupted download never leaves a truncated
        # best-model.ckpt that later runs would treat as valid.
        partial_path = _pretrained_ckpt + ".part"
        with open(partial_path, "wb") as handle:
            for chunk in response.iter_content(chunk_size=1 << 20):
                handle.write(chunk)
    os.replace(partial_path, _pretrained_ckpt)
| |
# Prefer the config saved in the experiment directory; fall back to the
# recipe's default config when the experiment has none.
conf_path = os.path.join(exp_dir, "conf.yml")
if not os.path.exists(conf_path):
    conf_path = "local/conf.yml"

# Read with an explicit encoding so parsing does not depend on the platform's
# locale. yaml.safe_load builds only plain Python types (dicts, lists, scalars).
with open(conf_path, encoding="utf-8") as f:
    train_conf = yaml.safe_load(f)
# Rebuild the architecture described by the training config and switch it to
# inference mode. The optimizer returned alongside it is not needed here.
best_model_path = os.path.join(exp_dir, "checkpoints", "best-model.ckpt")
sample_rate = train_conf["data"]["sample_rate"]
model, _ = make_model_and_optimizer(train_conf, sample_rate=sample_rate)
model.eval()

# Load the checkpoint onto the CPU so the conversion needs no GPU, then push
# its weights into the freshly built model. Presumably load_state_dict_in
# raises if the stored keys do not fit the model; confirm against
# asteroid.torch_utils.
# NOTE(review): torch.load unpickles the file, so only run this on
# checkpoints from a trusted source.
checkpoint = torch.load(best_model_path, map_location="cpu")
model = torch_utils.load_state_dict_in(checkpoint["state_dict"], model)
# Constructor arguments needed to re-instantiate the model: the masknet
# section first, then the filterbank section (on duplicate keys the
# filterbank value wins, matching the original update order).
model_args = {**train_conf["masknet"], **train_conf["filterbank"]}

# Drop everything up to and including the first "." in each weight name, so
# the keys address the bare model rather than its training wrapper (the
# prefix is presumably the wrapper's attribute name, e.g. "model."; confirm
# against the checkpoint). Keys without a "." are kept unchanged.
checkpoint["state_dict"] = OrderedDict(
    (key.split(".", 1)[-1], value)
    for key, value in checkpoint["state_dict"].items()
)

# Metadata stored next to the weights so the exported file is self-describing.
checkpoint["model_name"] = "MultiDecoderDPRNN"
checkpoint["sample_rate"] = sample_rate
checkpoint["model_args"] = model_args

output_path = "pytorch_model.bin"
torch.save(checkpoint, output_path)
print(f"saved checkpoint to {output_path}")