Source code for hpbandster.core.dispatcher

import threading
import logging
import queue
import time

import Pyro4


[docs]class Job(object): def __init__(self, id, **kwargs): self.id = id self.kwargs = kwargs self.timestamps = {} self.result = None self.exception = None self.worker_name = None
[docs] def time_it(self, which_time): self.timestamps[which_time] = time.time()
def __repr__(self): return(\ "job_id: " +str(self.id) + "\n" + \ "kwargs: " + str(self.kwargs) + "\n" + \ "result: " + str(self.result)+ "\n" +\ "exception: "+ str(self.exception) + "\n" )
[docs] def recreate_from_run(self, run): run.config_id run.budget run.error_logs run.loss run.info run.time_stamps
[docs]class Worker(object): def __init__(self, name, uri): self.name = name self.proxy = Pyro4.Proxy(uri) self.runs_job = None
[docs] def is_alive(self): try: self.proxy._pyroReconnect(1) except Pyro4.errors.ConnectionClosedError: return False except: raise return(True)
[docs] def shutdown(self): self.proxy.shutdown()
[docs] def is_busy(self): return(self.proxy.is_busy())
def __repr__(self): return(self.name)
[docs]class Dispatcher(object): """ The dispatcher is responsible for assigning tasks to free workers, report results back to the master and communicate to the nameserver. """ def __init__(self, new_result_callback, run_id='0', ping_interval=10, nameserver='localhost', nameserver_port=None, host=None, logger=None, queue_callback=None): """ Parameters ---------- new_result_callback: function function that will be called with a `Job instance <hpbandster.core.dispatcher.Job>`_ as argument. From the `Job` the result can be read and e.g. logged. run_id: str unique run_id associated with the HPB run ping_interval: int how often to ping for workers (in seconds) nameserver: str address of the Pyro4 nameserver nameserver_port: int port of Pyro4 nameserver host: str ip (or name that resolves to that) of the network interface to use logger: logging.Logger logger-instance for info and debug queue_callback: function gets called with the number of workers in the pool on every update-cycle """ self.new_result_callback = new_result_callback self.queue_callback = queue_callback self.run_id = run_id self.nameserver = nameserver self.nameserver_port = nameserver_port self.host = host self.ping_interval = int(ping_interval) self.shutdown_all_threads = False if logger is None: self.logger = logging.getLogger('hpbandster') else: self.logger = logger self.worker_pool = {} self.waiting_jobs = queue.Queue() self.running_jobs = {} self.idle_workers = set() self.thread_lock = threading.Lock() self.runner_cond = threading.Condition(self.thread_lock) self.discover_cond = threading.Condition(self.thread_lock) self.pyro_id="hpbandster.run_%s.dispatcher"%self.run_id
[docs] def run(self): with self.discover_cond: t1 = threading.Thread(target=self.discover_workers, name='discover_workers') t1.start() self.logger.info('DISPATCHER: started the \'discover_worker\' thread') t2 = threading.Thread(target=self.job_runner, name='job_runner') t2.start() self.logger.info('DISPATCHER: started the \'job_runner\' thread') self.pyro_daemon = Pyro4.core.Daemon(host=self.host) with Pyro4.locateNS(host=self.nameserver, port=self.nameserver_port) as ns: uri = self.pyro_daemon.register(self, self.pyro_id) ns.register(self.pyro_id, uri) self.logger.info("DISPATCHER: Pyro daemon running on %s"%(self.pyro_daemon.locationStr)) self.pyro_daemon.requestLoop() with self.discover_cond: self.shutdown_all_threads = True self.logger.info('DISPATCHER: Dispatcher shutting down') self.runner_cond.notify_all() self.discover_cond.notify_all() with Pyro4.locateNS(self.nameserver, port=self.nameserver_port) as ns: ns.remove(self.pyro_id) t1.join() self.logger.debug('DISPATCHER: \'discover_worker\' thread exited') t2.join() self.logger.debug('DISPATCHER: \'job_runner\' thread exited') self.logger.info('DISPATCHER: shut down complete')
[docs] def shutdown_all_workers(self, rediscover=False): with self.discover_cond: for worker in self.worker_pool.values(): worker.shutdown() if rediscover: time.sleep(1) self.discover_cond.notify()
[docs] def shutdown(self, shutdown_workers=False): if shutdown_workers: self.shutdown_all_workers() with self.runner_cond: self.pyro_daemon.shutdown()
[docs] @Pyro4.expose @Pyro4.oneway def trigger_discover_worker(self): #time.sleep(1) self.logger.info("DISPATCHER: A new worker triggered discover_worker") with self.discover_cond: self.discover_cond.notify()
[docs] def discover_workers(self): self.discover_cond.acquire() sleep_interval = 1 while True: self.logger.debug('DISPATCHER: Starting worker discovery') update = False with Pyro4.locateNS(host=self.nameserver, port=self.nameserver_port) as ns: worker_names = ns.list(prefix="hpbandster.run_%s.worker."%self.run_id) self.logger.debug("DISPATCHER: Found %i potential workers, %i currently in the pool."%(len(worker_names), len(self.worker_pool))) for wn, uri in worker_names.items(): if not wn in self.worker_pool: w = Worker(wn, uri) if not w.is_alive(): self.logger.debug('DISPATCHER: skipping dead worker, %s'%wn) continue update = True self.logger.info('DISPATCHER: discovered new worker, %s'%wn) self.worker_pool[wn] = w # check the current list of workers crashed_jobs = set() all_workers = list(self.worker_pool.keys()) for wn in all_workers: # remove dead entries from the nameserver if not self.worker_pool[wn].is_alive(): self.logger.info('DISPATCHER: removing dead worker, %s'%wn) update = True # todo check if there were jobs running on that that need to be rescheduled current_job = self.worker_pool[wn].runs_job if not current_job is None: self.logger.info('Job %s was not completed'%str(current_job)) crashed_jobs.add(current_job) del self.worker_pool[wn] self.idle_workers.discard(wn) continue if not self.worker_pool[wn].is_busy(): self.idle_workers.add(wn) # try to submit more jobs if something changed if update: if not self.queue_callback is None: self.discover_cond.release() self.queue_callback(len(self.worker_pool)) self.discover_cond.acquire() self.runner_cond.notify() for crashed_job in crashed_jobs: self.discover_cond.release() self.register_result(crashed_job, {'result': None, 'exception': 'Worker died unexpectedly.'}) self.discover_cond.acquire() self.logger.debug('DISPATCHER: Finished worker discovery') #if (len(self.worker_pool) == 0 ): # ping for new workers if no workers are currently available # self.logger.debug('No workers available! Keep pinging') # self.discover_cond.wait(sleep_interval) # sleep_interval *= 2 #else: self.discover_cond.wait(self.ping_interval) if self.shutdown_all_threads: self.logger.debug('DISPATCHER: discover_workers shutting down') self.runner_cond.notify() self.discover_cond.release() return
[docs] def number_of_workers(self): with self.discover_cond: return(len(self.worker_pool))
[docs] def job_runner(self): self.runner_cond.acquire() while True: while self.waiting_jobs.empty() or len(self.idle_workers) == 0: self.logger.debug('DISPATCHER: jobs to submit = %i, number of idle workers = %i -> waiting!'%(self.waiting_jobs.qsize(), len(self.idle_workers) )) self.runner_cond.wait() self.logger.debug('DISPATCHER: Trying to submit another job.') if self.shutdown_all_threads: self.logger.debug('DISPATCHER: job_runner shutting down') self.discover_cond.notify() self.runner_cond.release() return job = self.waiting_jobs.get() wn = self.idle_workers.pop() worker = self.worker_pool[wn] self.logger.debug('DISPATCHER: starting job %s on %s'%(str(job.id),worker.name)) job.time_it('started') worker.runs_job = job.id worker.proxy.start_computation(self, job.id, **job.kwargs) job.worker_name = wn self.running_jobs[job.id] = job self.logger.debug('DISPATCHER: job %s dispatched on %s'%(str(job.id),worker.name))
[docs] def submit_job(self, id, **kwargs): self.logger.debug('DISPATCHER: trying to submit job %s'%str(id)) with self.runner_cond: job = Job(id, **kwargs) job.time_it('submitted') self.waiting_jobs.put(job) self.logger.debug('DISPATCHER: trying to notify the job_runner thread.') self.runner_cond.notify()
[docs] @Pyro4.expose @Pyro4.callback @Pyro4.oneway def register_result(self, id=None, result=None): self.logger.debug('DISPATCHER: job %s finished'%(str(id))) with self.runner_cond: self.logger.debug('DISPATCHER: register_result: lock acquired') # fill in missing information job = self.running_jobs[id] job.time_it('finished') job.result = result['result'] job.exception = result['exception'] self.logger.debug('DISPATCHER: job %s on %s finished'%(str(job.id),job.worker_name)) self.logger.debug(str(job)) # delete job del self.running_jobs[id] # label worker as idle again try: self.worker_pool[job.worker_name].runs_job = None self.worker_pool[job.worker_name].proxy._pyroRelease() self.idle_workers.add(job.worker_name) # notify the job_runner to check for more jobs to run self.runner_cond.notify() except KeyError: # happens for crashed workers, but we can just continue pass except: raise # call users callback function to register the result # needs to be with the condition released, as the master can call # submit_job quickly enough to cause a dead-lock self.new_result_callback(job)