Source code for nvflare.app_opt.sklearn.data_loader

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import pandas as pd

# Lookup table from a file-format name (as produced by get_file_format) to the
# pandas function used to read files of that format.
pd_readers = dict(
    csv=pd.read_csv,
    xls=pd.read_excel,
    xlsx=pd.read_excel,
)


def _to_data_tuple(data):
    data_num = data.shape[0]
    # split to feature and label
    x = data.iloc[:, 1:]
    y = data.iloc[:, 0]
    return x.to_numpy(), y.to_numpy(), data_num


def get_pandas_reader(data_path: str):
    """Return the pandas reader function for the format of ``data_path``.

    The format is obtained from ``get_file_format`` and looked up in the
    module-level ``pd_readers`` table.

    Args:
        data_path: path of the data file that will be read.

    Returns:
        The pandas reader callable (e.g. ``pd.read_csv``) for that format.

    Raises:
        ValueError: if ``pd_readers`` has no entry for the detected format.
    """
    from nvflare.app_common.utils.file_utils import get_file_format

    fmt = get_file_format(data_path)
    pandas_reader = pd_readers.get(fmt)
    if pandas_reader is not None:
        return pandas_reader
    raise ValueError(f"no pandas reader for given file format {fmt}")
def load_data(data_path: str, require_header: bool = False):
    """Load an entire tabular data file as features, labels and row count.

    Args:
        data_path: path of the data file; its format selects the pandas reader.
        require_header: if True, let the reader treat the first line as a
            header row (column names) instead of data. Only honored when the
            reader accepts a ``header`` argument; otherwise every line is data.

    Returns:
        Tuple ``(x, y, n)`` from ``_to_data_tuple``: feature array, label array
        (first column) and number of data rows.
    """
    reader = get_pandas_reader(data_path)
    # ``hasattr(reader, "header")`` would test for a function *attribute*,
    # which pandas readers never have, so require_header was silently ignored.
    # Check whether the reader *accepts* a ``header`` argument instead.
    reader_params = inspect.signature(reader).parameters
    if require_header and "header" in reader_params:
        data = reader(data_path)
    else:
        data = reader(data_path, header=None)
    return _to_data_tuple(data)
def load_data_for_range(data_path: str, start: int, end: int, require_header: bool = False):
    """Load data rows ``[start, end)`` of a tabular data file.

    Row positions count data rows only: when a header is kept, the header
    line is not counted, so the result matches ``iloc[start:end]`` on the
    fully loaded frame.

    If the reader accepts ``skiprows`` and ``nrows`` (and ``0 <= start <= end``),
    only the requested rows are turned into a DataFrame; otherwise the whole
    file is loaded and sliced with ``iloc``.

    Args:
        data_path: path of the data file; its format selects the pandas reader.
        start: position of the first data row to load (inclusive).
        end: position one past the last data row to load (exclusive).
        require_header: if True, let the reader treat the first line as a
            header row. Only honored when the reader accepts ``header``.

    Returns:
        Tuple ``(x, y, n)`` from ``_to_data_tuple``: feature array, label array
        (first column) and number of rows loaded.
    """
    reader = get_pandas_reader(data_path)
    # Inspect the call signature: ``hasattr(reader, "skiprows")`` / ``"header"``
    # test function attributes, which pandas readers don't have, so both the
    # partial-read path and require_header were previously dead.
    reader_params = inspect.signature(reader).parameters
    keep_header = require_header and "header" in reader_params
    header_kwargs = {} if keep_header else {"header": None}

    supports_partial = "skiprows" in reader_params and "nrows" in reader_params
    if supports_partial and 0 <= start <= end:
        # With a header, keep file line 0 and skip the first ``start`` data
        # lines (file lines 1..start); without one, skip the first ``start`` lines.
        skip = range(1, start + 1) if keep_header else start
        try:
            data = reader(data_path, skiprows=skip, nrows=end - start, **header_kwargs)
        except pd.errors.EmptyDataError:
            # ``start`` lies past the end of the file: fall back to slicing so
            # the result is an empty frame rather than a parse error.
            data = reader(data_path, **header_kwargs).iloc[start:end]
    else:
        data = reader(data_path, **header_kwargs).iloc[start:end]
    return _to_data_tuple(data)