# Source code for nvflare.app_opt.pt.model_persistence_format_manager
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import torch
from nvflare.apis.dxo import MetaKey
from nvflare.app_common.abstract.model import (
ModelLearnable,
ModelLearnableKey,
make_model_learnable,
validate_model_learnable,
)
from nvflare.app_common.app_constant import ModelFormat
[docs]
class PTModelPersistenceFormatManager(object):
    """Convert between persisted PyTorch model dicts and NVFlare ``ModelLearnable`` objects.

    Two input shapes are accepted:

    * a "simple" dict mapping variable names directly to values, or
    * a "complex" dict of dicts that stores the weights under ``PERSISTENCE_KEY_MODEL``
      and, optionally, meta props under ``PERSISTENCE_KEY_META_PROPS`` and a training
      configuration under ``PERSISTENCE_KEY_TRAIN_CONF``. Any other top-level entries
      are kept verbatim so they are written back out by ``to_persistence_dict``.

    ``to_persistence_dict`` always produces the complex shape.
    """

    PERSISTENCE_KEY_MODEL = "model"
    PERSISTENCE_KEY_TRAIN_CONF = "train_conf"
    PERSISTENCE_KEY_META_PROPS = "meta_props"

    def __init__(self, data: dict, default_train_conf=None):
        """Manage the format for model persistence.

        Note that ``self.var_dict`` references the weight dict inside ``data`` (it is not
        copied), so ``update`` modifies that dict in place.

        Args:
            data (dict): either the dictionary mapping variables to values or a dict of dict.
            default_train_conf (dict, optional): configuration for train, used when ``data``
                carries no train configuration (or a falsy one). Defaults to None.

        Raises:
            TypeError: when data is not a dictionary
        """
        if not isinstance(data, dict):
            raise TypeError("data must be a dict but got {}".format(type(data)))

        self.var_dict = None
        self.meta = None
        self.train_conf = None
        self.other_props = {}  # other props from the original data that need to be kept

        if self.PERSISTENCE_KEY_MODEL not in data:
            # this is a simple weight dict
            self.var_dict = data
        else:
            # dict of dicts
            self.var_dict = data[self.PERSISTENCE_KEY_MODEL]
            self.meta = data.get(self.PERSISTENCE_KEY_META_PROPS, None)
            self.train_conf = data.get(self.PERSISTENCE_KEY_TRAIN_CONF, None)

            # we need to keep other props, if any, so they can be kept when persisted
            reserved_keys = (
                self.PERSISTENCE_KEY_MODEL,
                self.PERSISTENCE_KEY_META_PROPS,
                self.PERSISTENCE_KEY_TRAIN_CONF,
            )
            for k, v in data.items():
                if k not in reserved_keys:
                    self.other_props[k] = v

        if not self.train_conf:
            self.train_conf = default_train_conf

    def _get_processed_vars(self) -> dict:
        """Return the ``MetaKey.PROCESSED_KEYS`` mapping from the meta props.

        Variables whose entry is truthy are passed through without tensor/numpy
        conversion by ``to_model_learnable`` and ``to_persistence_dict``.

        Returns:
            dict: the mapping, or an empty dict when there are no meta props or the
            entry is missing or None.
        """
        if not self.meta:
            return {}
        # ``or {}`` also covers an explicit None stored under PROCESSED_KEYS,
        # which would otherwise break the ``.get`` calls made by the callers.
        return self.meta.get(MetaKey.PROCESSED_KEYS) or {}

    def to_model_learnable(self, exclude_vars) -> ModelLearnable:
        """Build a ``ModelLearnable`` from the stored variables and meta props.

        Args:
            exclude_vars: variables whose name matches ``exclude_vars.search(name)`` are
                skipped; pass None (or any falsy value) to keep every variable.
                Presumably a compiled ``re.Pattern`` -- confirm against callers.

        Returns:
            ModelLearnable: created by ``make_model_learnable`` from the weights and
            ``self.meta``. Variables flagged as processed are passed through unchanged.
            torch tensors are detached, moved to CPU and converted to numpy arrays.
            Any other value (e.g. a numpy array written by ``update``) is passed
            through as is.
        """
        processed_vars = self._get_processed_vars()
        weights = {}
        for k, v in self.var_dict.items():
            if exclude_vars and exclude_vars.search(k):
                continue
            if processed_vars.get(k, False) or not isinstance(v, torch.Tensor):
                weights[k] = v
            else:
                # detach() so tensors that require grad can still be converted to numpy
                weights[k] = v.detach().cpu().numpy()
        return make_model_learnable(weights, self.meta)

    def to_persistence_dict(self) -> dict:
        """Return the data in the complex (dict of dicts) format, ready to be saved.

        Returns:
            dict: an ``OrderedDict`` whose ``PERSISTENCE_KEY_MODEL`` entry maps each
            variable to ``torch.as_tensor(value)``; processed variables are kept
            unchanged. Meta props and train config are included only when truthy,
            followed by any other props carried over from the original data.
        """
        processed_vars = self._get_processed_vars()
        weights_dict = OrderedDict()
        for k, v in self.var_dict.items():
            weights_dict[k] = v if processed_vars.get(k, False) else torch.as_tensor(v)

        # always use complex format for saving
        persistence_dict = OrderedDict()
        persistence_dict[self.PERSISTENCE_KEY_MODEL] = weights_dict
        if self.meta:
            persistence_dict[self.PERSISTENCE_KEY_META_PROPS] = self.meta
        if self.train_conf:
            persistence_dict[self.PERSISTENCE_KEY_TRAIN_CONF] = self.train_conf
        # __init__ filters the three reserved keys out of other_props
        persistence_dict.update(self.other_props)
        return persistence_dict

    def update(self, ml: ModelLearnable):
        """Update the persistence data with the learned values.

        Args:
            ml (ModelLearnable): updated information to be merged into existing ModelLearnable

        Raises:
            ValueError: if ``validate_model_learnable`` reports a problem with ``ml``.
        """
        err = validate_model_learnable(ml)
        if err:
            raise ValueError(err)

        # the meta props are replaced as a whole, not merged
        self.meta = ml.get(ModelLearnableKey.META, None)

        # update with value of the model learnable
        # note that the original weights that are not learned are still kept!
        learned_weights = ml.get(ModelLearnableKey.WEIGHTS, {})
        self.var_dict.update(learned_weights)