Source code for nvflare.client.api

# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from threading import Lock
from typing import Any, Dict, Optional, Union

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.app_common.abstract.fl_model import FLModel

# this import is to let existing scripts import client.api
from .api_context import ClientAPIType  # noqa: F401
from .api_context import APIContext

# Serializes lookup/creation of contexts inside init().
global_context_lock = Lock()
# Cache of contexts created by init(), keyed by the (rank, config_file) tuple.
context_dict: Dict = dict()
# Context used when callers pass ctx=None; assigned by init() when it creates a new context.
default_context = None


def get_context(ctx: Optional[APIContext] = None) -> APIContext:
    """Gets an APIContext.

    Args:
        ctx (Optional[APIContext]): The context to use. If None, the module's
            default context is used, which is the context most recently
            created by ``init()``. Defaults to None.

    Raises:
        RuntimeError: if ``ctx`` is None and no default context has been set.

    Returns:
        An APIContext.
    """
    # Explicit None checks: a context object must never be skipped merely
    # because it happens to evaluate as falsy.
    if ctx is not None:
        return ctx
    if default_context is not None:
        return default_context
    raise RuntimeError("APIContext is None")
def init(rank: Optional[Union[str, int]] = None, config_file: Optional[str] = None) -> APIContext:
    """Initializes NVFlare Client API environment.

    Contexts are cached per ``(rank, config_file)`` pair. The first call with a
    given pair creates a new APIContext and makes it the default context used
    when other functions in this module are called with ``ctx=None``. A later
    call with the same pair returns the cached context unchanged, leaves the
    default context as it is, and logs a warning.

    Args:
        rank (Optional[Union[str, int]]): local rank of the process. It is only
            useful when the training script has multiple worker processes
            (for example multi GPU). Integers are converted to strings.
        config_file (Optional[str]): client api configuration.

    Raises:
        ValueError: if ``rank`` is not None, a string, or an integer.

    Returns:
        APIContext
    """
    # Only rebound name needs a global declaration; context_dict is mutated in place.
    global default_context

    # subsequent logic assumes rank is a string
    if rank is not None:
        if isinstance(rank, int):
            rank = str(rank)
        elif not isinstance(rank, str):
            raise ValueError(f"rank must be a string or an integer but got {type(rank)}")

    key = (rank, config_file)
    with global_context_lock:
        local_ctx = context_dict.get(key)
        if local_ctx is None:
            local_ctx = APIContext(rank=rank, config_file=config_file)
            context_dict[key] = local_ctx
            default_context = local_ctx
        else:
            # Use a named logger: the root-level logging.warning() convenience
            # function calls basicConfig() when the root logger has no handlers,
            # which would silently reconfigure the user's logging.
            logging.getLogger(__name__).warning(
                "Warning: called init() more than once with same parameters. "
                "The subsequent calls are ignored"
            )
    return local_ctx
def receive(timeout: Optional[float] = None, ctx: Optional[APIContext] = None) -> Optional[FLModel]:
    """Receives model from NVFlare side.

    Args:
        timeout (Optional[float]): forwarded unchanged to the underlying API's
            ``receive`` (presumably a wait limit in seconds; confirm against the API).
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Raises:
        RuntimeError: if ``ctx`` is None and no default context exists.

    Returns:
        An FLModel received (annotated Optional, so it may be None).
    """
    api = get_context(ctx).api
    return api.receive(timeout)
def send(model: FLModel, clear_cache: bool = True, ctx: Optional[APIContext] = None) -> None:
    """Sends the model to NVFlare side.

    Args:
        model (FLModel): The FLModel object to be sent.
        clear_cache (bool): Whether to clear the cache after send.
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Raises:
        TypeError: if ``model`` is not an FLModel (checked before the context is resolved).
        RuntimeError: if ``ctx`` is None and no default context exists.
    """
    # The type check happens first, so a bad model is reported even without a context.
    if isinstance(model, FLModel):
        return get_context(ctx).api.send(model, clear_cache)
    raise TypeError("model needs to be an instance of FLModel")
def system_info(ctx: Optional[APIContext] = None) -> Dict:
    """Gets NVFlare system information.

    System information will be available after a valid FLModel is received.
    It does not retrieve information actively.

    Note:
        system information includes job id and site name.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        A dict of system information.
    """
    return get_context(ctx).api.system_info()
def get_config(ctx: Optional[APIContext] = None) -> Dict:
    """Gets the ClientConfig dictionary.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        A dict of the configuration used in Client API.
    """
    api = get_context(ctx).api
    return api.get_config()
def get_job_id(ctx: Optional[APIContext] = None) -> str:
    """Gets job id.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        The current job id.
    """
    return get_context(ctx).api.get_job_id()
def get_site_name(ctx: Optional[APIContext] = None) -> str:
    """Gets site name.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        The site name of this client.
    """
    api = get_context(ctx).api
    return api.get_site_name()
def get_task_name(ctx: Optional[APIContext] = None) -> str:
    """Gets task name.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        The task name.
    """
    return get_context(ctx).api.get_task_name()
def is_running(ctx: Optional[APIContext] = None) -> bool:
    """Returns whether the NVFlare system is up and running.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        True, if the system is up and running. False, otherwise.
    """
    api = get_context(ctx).api
    return api.is_running()
def is_train(ctx: Optional[APIContext] = None) -> bool:
    """Returns whether the current task is a training task.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        True, if the current task is a training task. False, otherwise.
    """
    return get_context(ctx).api.is_train()
def is_evaluate(ctx: Optional[APIContext] = None) -> bool:
    """Returns whether the current task is an evaluate task.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        True, if the current task is an evaluate task. False, otherwise.
    """
    api = get_context(ctx).api
    return api.is_evaluate()
def is_submit_model(ctx: Optional[APIContext] = None) -> bool:
    """Returns whether the current task is a submit_model task.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        True, if the current task is a submit_model. False, otherwise.
    """
    return get_context(ctx).api.is_submit_model()
def log(key: str, value: Any, data_type: AnalyticsDataType, ctx: Optional[APIContext] = None, **kwargs):
    """Logs a key value pair.

    We suggest users use the high-level APIs in nvflare/client/tracking.py

    Args:
        key (str): key string.
        value (Any): value to log.
        data_type (AnalyticsDataType): the data type of the "value".
        ctx (Optional[APIContext]): context to use; None selects the default context.
        kwargs: additional arguments to be included; forwarded unchanged.

    Returns:
        whether the key value pair is logged successfully
    """
    api = get_context(ctx).api
    return api.log(key, value, data_type, **kwargs)
def clear(ctx: Optional[APIContext] = None):
    """Clears the cache.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        Whatever the underlying API's ``clear`` returns.
    """
    return get_context(ctx).api.clear()
def shutdown(ctx: Optional[APIContext] = None):
    """Releases all threads and resources used by the API and stops operation.

    Args:
        ctx (Optional[APIContext]): context to use; None selects the default context.

    Returns:
        Whatever the underlying API's ``shutdown`` returns.
    """
    api = get_context(ctx).api
    return api.shutdown()