# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pprint import pprint
from nvflare.ha.overseer_agent import HttpOverseerAgent
def setup_basic_info():
    """Build an ``HttpOverseerAgent`` from command-line options.

    The options are parsed from ``sys.argv`` with :mod:`argparse` and forwarded to the
    ``HttpOverseerAgent`` constructor. Only ``--project`` declares a default
    (``"example_project"``); any other option omitted on the command line is
    forwarded as ``None``.

    When ``--ca_path`` is supplied, the agent's secure context is configured from
    ``--ca_path``, ``--cert_path`` and ``--prv_key_path``.

    This function only constructs the agent; it does not start it.

    Returns:
        The configured ``HttpOverseerAgent`` instance.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--project", type=str, default="example_project", help="project name")
    parser.add_argument("-r", "--role", type=str, help="role (server, client or admin)")
    parser.add_argument("-n", "--name", type=str, help="globally unique name")
    parser.add_argument("-f", "--fl_port", type=str, help="fl port number")
    parser.add_argument("-a", "--admin_port", type=str, help="adm port number")
    parser.add_argument("-s", "--sleep", type=float, help="sleep (seconds) in heartbeat")
    parser.add_argument("-c", "--ca_path", type=str, help="root CA path")
    parser.add_argument("-o", "--overseer_url", type=str, help="Overseer URL")
    parser.add_argument("-t", "--cert_path", type=str, help="cert path")
    parser.add_argument("-v", "--prv_key_path", type=str, help="private key path")
    args = parser.parse_args()

    overseer_agent = HttpOverseerAgent(
        overseer_end_point=args.overseer_url,
        project=args.project,
        role=args.role,
        name=args.name,
        fl_port=args.fl_port,
        admin_port=args.admin_port,
        heartbeat_interval=args.sleep,
    )

    # The secure context is only set up when a root CA path was given.
    if args.ca_path:
        overseer_agent.set_secure_context(
            ca_path=args.ca_path, cert_path=args.cert_path, prv_key_path=args.prv_key_path
        )
    return overseer_agent
def main():
    """Run an interactive console that drives an overseer agent.

    The agent is created by ``setup_basic_info`` and started with ``simple_callback``
    as its callback (``conditional_cb=True``). The console then loops on a prompt;
    input is stripped and upper-cased before matching:

    * ``p`` - call ``pause()`` on the agent.
    * ``r`` - call ``resume()`` on the agent.
    * ``e`` - call ``end()`` on the agent and leave the loop.
    * ``d`` - pretty-print ``overseer_agent.overseer_info``.
    * ``s <index>`` - promote the entry at ``<index>`` of the ``"sp_list"`` found in
      ``overseer_info`` and pretty-print the JSON response.

    Empty input and unrecognised commands are ignored and the prompt is shown again.
    """
    overseer_agent = setup_basic_info()
    overseer_agent.start(simple_callback, conditional_cb=True)
    while True:
        answer = input("(p)ause/(r)esume/(s)witch/(d)ump/(e)nd? ")
        normalized_answer = answer.strip().upper()
        if not normalized_answer:
            continue

        if normalized_answer == "P":
            overseer_agent.pause()
        elif normalized_answer == "R":
            overseer_agent.resume()
        elif normalized_answer == "E":
            overseer_agent.end()
            break
        elif normalized_answer == "D":
            pprint(overseer_agent.overseer_info)
        elif normalized_answer[0] == "S":
            split = normalized_answer.split()
            if len(split) != 2:
                print("expect sp index but got nothing. Please provide the sp index to be promoted")
                continue

            # A non-numeric index must not take the whole console down.
            try:
                sp_index = int(split[1])
            except ValueError:
                print(f"invalid sp index {split[1]!r}. Please provide an integer")
                continue

            # "sp_list" may be absent from overseer_info; treat that as an empty list
            # so the user gets the out-of-range message instead of a TypeError.
            sp_list = overseer_agent.overseer_info.get("sp_list") or []

            # Negative values would silently index from the end of the list.
            if sp_index < 0:
                print("index out of range")
                continue

            try:
                sp = sp_list[sp_index]
            except IndexError:
                print("index out of range")
            else:
                resp = overseer_agent.promote_sp(sp.get("sp_end_point"))
                pprint(resp.json())
def simple_callback(overseer_agent):
    """Print the value returned by the agent's ``get_primary_sp()``.

    ``main`` passes this function to ``overseer_agent.start`` as the callback.

    Args:
        overseer_agent: Object exposing a ``get_primary_sp()`` method.
    """
    print(f"\nGot callback {overseer_agent.get_primary_sp()}")
# Start the interactive console only when executed as a script, not on import.
if __name__ == "__main__":
    main()