Source code for nvflare.client.converter_utils

# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

from nvflare.app_common.abstract.params_converter import ParamsConverter
from nvflare.client.config import ExchangeFormat
from nvflare.fuel.utils.import_utils import optional_import


def _load_converter(module: str, name: str, format_name: str):
    """Resolve a converter class by module path and name, failing loudly if absent.

    Args:
        module: Dotted path of the module expected to define the converter.
        name: Attribute name of the converter class inside ``module``.
        format_name: Exchange format the converter serves; used only to build
            the error message.

    Returns:
        The object ``optional_import`` resolved for ``module``/``name``.

    Raises:
        RuntimeError: If ``optional_import`` reports the import as unsuccessful
            (e.g. the optional framework dependency is not installed).
    """
    resolved, available = optional_import(module=module, name=name)
    if available:
        return resolved
    raise RuntimeError(f"Can't import {name} for {format_name} exchange format")


def create_default_params_converters(
    server_expected_format: str,
    params_exchange_format: str,
    train_task_name: str,
    eval_task_name: str,
    submit_model_task_name: str,
) -> Tuple[Optional[ParamsConverter], Optional[ParamsConverter]]:
    """Create default from/to NVFlare converters for common Client API formats.

    Converters are only produced when the server expects NumPy params and the
    client exchanges params in a format with a known converter pair (PyTorch or
    Keras layer weights). In every other combination ``(None, None)`` is
    returned.

    Args:
        server_expected_format: Params format the server side expects.
        params_exchange_format: Params format the client code exchanges.
        train_task_name: Name of the training task.
        eval_task_name: Name of the evaluation task.
        submit_model_task_name: Name of the submit-model task.

    Returns:
        A ``(from_nvflare, to_nvflare)`` pair. The first is a NumPy-to-client-format
        converter built with ``[train_task_name, eval_task_name]``. The second is a
        client-format-to-NumPy converter built with
        ``[train_task_name, submit_model_task_name]``.

    Raises:
        RuntimeError: If the converter classes for the selected format cannot be
            imported (raised by ``_load_converter``).
    """
    # Conversion is only needed when the server speaks NumPy.
    if server_expected_format != ExchangeFormat.NUMPY:
        return None, None

    # Ordered lookup table:
    # (client format, converter module, NumPy->format class, format->NumPy class).
    # Scanned with ``==`` in order, exactly mirroring an if/if chain.
    supported = (
        (
            ExchangeFormat.PYTORCH,
            "nvflare.app_opt.pt.numpy_params_converter",
            "NumpyToPTParamsConverter",
            "PTToNumpyParamsConverter",
        ),
        (
            ExchangeFormat.KERAS_LAYER_WEIGHTS,
            "nvflare.app_opt.tf.params_converter",
            "NumpyToKerasModelParamsConverter",
            "KerasModelToNumpyParamsConverter",
        ),
    )

    for fmt, module_path, from_numpy_name, to_numpy_name in supported:
        if params_exchange_format == fmt:
            # Load both classes before instantiating either, so a missing
            # dependency surfaces as a RuntimeError up front.
            from_nvflare_cls = _load_converter(
                module=module_path,
                name=from_numpy_name,
                format_name=fmt,
            )
            to_nvflare_cls = _load_converter(
                module=module_path,
                name=to_numpy_name,
                format_name=fmt,
            )
            incoming_tasks = [train_task_name, eval_task_name]
            outgoing_tasks = [train_task_name, submit_model_task_name]
            return from_nvflare_cls(incoming_tasks), to_nvflare_cls(outgoing_tasks)

    # No converter pair is registered for this client exchange format.
    return None, None