Source code for nvflare.app_common.abstract.fl_model

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Dict, Optional, Union

from nvflare.apis.fl_constant import FLMetaKey
from nvflare.fuel.utils.validation_utils import check_object_type

[docs]class ParamsType(str, Enum): FULL = "FULL" DIFF = "DIFF"
[docs]class FLModelConst: PARAMS_TYPE = "params_type" PARAMS = "params" OPTIMIZER_PARAMS = "optimizer_params" METRICS = "metrics" CURRENT_ROUND = "current_round" START_ROUND = "start_round" TOTAL_ROUNDS = "total_rounds" META = "meta"
[docs]class MetaKey(FLMetaKey): pass
[docs]class FLModel: def __init__( self, params_type: Union[None, str, ParamsType] = None, params: Any = None, optimizer_params: Any = None, metrics: Optional[Dict] = None, start_round: Optional[int] = 0, current_round: Optional[int] = None, total_rounds: Optional[int] = None, meta: Optional[Dict] = None, ): """FLModel is a standardize data structure for NVFlare to communicate with external systems. Args: params_type: type of the parameters. It only describes the "params". If params_type is None, params need to be None. If params is provided but params_type is not provided, then it will be treated as FULL. params: model parameters, for example: model weights for deep learning. optimizer_params: optimizer parameters. For many cases, the optimizer parameters don't need to be transferred during FL training. metrics: evaluation metrics such as loss and scores. current_round: the current FL rounds. A round means round trip between client/server during training. None for inference. total_rounds: total number of FL rounds. A round means round trip between client/server during training. None for inference. meta: metadata dictionary used to contain any key-value pairs to facilitate the process. """ if params_type is None: if params is not None: params_type = ParamsType.FULL else: params_type = ParamsType(params_type) if params_type == ParamsType.FULL or params_type == ParamsType.DIFF: if params is None: raise ValueError(f"params must be provided when params_type is {params_type}") self.params_type = params_type self.params = params self.optimizer_params = optimizer_params self.metrics = metrics self.start_round = start_round self.current_round = current_round self.total_rounds = total_rounds if meta is not None: check_object_type("meta", meta, dict) else: meta = {} self.meta = meta self._summary: dict = {} def _add_to_summary(self, kvs: Dict): for key, value in kvs.items(): if value: if isinstance(value, dict): self._summary[key] = len(value) elif isinstance(value, ParamsType): self._summary[key] = value elif isinstance(value, int): self._summary[key] = value else: self._summary[key] = type(value)
[docs] def summary(self): kvs = dict( params=self.params, optimizer_params=self.optimizer_params, metrics=self.metrics, meta=self.meta, params_type=self.params_type, start_round=self.start_round, current_round=self.current_round, total_rounds=self.total_rounds, ) self._add_to_summary(kvs) return self._summary
def __repr__(self): return str(self.summary()) def __str__(self): return str(self.summary())