# Source code for runhouse.resources.functionals.mapper

import concurrent.futures
import contextvars
import logging
from typing import Callable, List, Optional, Union

try:
    from tqdm import tqdm
except ImportError:
    # Fallback when tqdm isn't installed: a transparent pass-through that
    # returns the iterable unchanged, so progress-bar call sites still work.
    # Accepting ``iterable`` as a named first parameter (instead of ``args[0]``)
    # avoids an IndexError when called with no positional args and also honors
    # tqdm's ``iterable=`` keyword form.
    def tqdm(iterable=None, *args, **kwargs):
        return iterable


from runhouse.resources.envs.env import Env
from runhouse.resources.functions import function, Function
from runhouse.resources.hardware.cluster import Cluster

from runhouse.resources.module import Module

logger = logging.getLogger(__name__)


class Mapper(Module):
    def __init__(
        self,
        module: Module = None,
        method: str = None,
        replicas: Union[None, int, List[Module]] = None,
        concurrency: int = 1,
        **kwargs,
    ):
        """
        Runhouse Mapper object. It is used for mapping a function or module method over a list of inputs,
        across a series of replicas.

        .. note::
            To create a Mapper, please use the factory method :func:`mapper`.
        """
        super().__init__(**kwargs)
        self.module = module  # the module (or function) whose method is mapped
        self.method = method  # default method name to call on each replica
        self.concurrency = concurrency  # concurrent calls per replica (threads)
        # Replica count requested before the module reached a cluster; applied
        # later when the mapper itself is sent via ``to()``.
        self._num_auto_replicas = None
        self._auto_replicas = []  # replicas created automatically via module.replicate()
        self._user_replicas = []  # replicas supplied explicitly by the user
        self._last_called = 0  # round-robin cursor into self.replicas
        if isinstance(replicas, int):
            if self.module.system:
                # Only add replicas if the replicated module is already on a cluster
                if replicas > self.num_replicas and replicas > 0:
                    self.add_replicas(replicas)
            else:
                # Otherwise, store this for later once we've sent the mapper to the cluster
                self._num_auto_replicas = replicas
        elif isinstance(replicas, list):
            self._user_replicas = replicas
    @property
    def replicas(self):
        # The base module always occupies index 0, followed by the auto-created
        # and then the user-provided replicas.
        return [self.module] + self._auto_replicas + self._user_replicas

    @property
    def num_replicas(self):
        return len(self.replicas)

    def add_replicas(self, replicas: Union[int, List[Module]]):
        # int: top up auto-replicas so the *total* replica count reaches ``replicas``;
        # list: register the given modules as user-managed replicas.
        if isinstance(replicas, int):
            new_replicas = replicas - self.num_replicas
            logger.info(f"Adding {new_replicas} replicas")
            self._add_auto_replicas(new_replicas)
        else:
            self._user_replicas.extend(replicas)

    def drop_replicas(self, num_replicas: int, reap: bool = True):
        # Remove the newest ``num_replicas`` auto-replicas. When ``reap`` is True,
        # also kill each replica's env on its cluster before forgetting it.
        # NOTE(review): only auto-replicas are dropped; user replicas are untouched.
        if reap:
            for replica in self._auto_replicas[-num_replicas:]:
                replica.system.kill(replica.env.name)
        self._auto_replicas = self._auto_replicas[:-num_replicas]

    def _add_auto_replicas(self, num_replicas: int):
        # Delegates to the module's own replicate mechanism to create copies.
        self._auto_replicas.extend(self.module.replicate(num_replicas))

    def increment_counter(self):
        # Advance the round-robin cursor and return the new index. Because the
        # increment happens before the bounds check, the first call returns
        # index 1, so the base module (index 0) is used last in each rotation.
        self._last_called += 1
        if self._last_called >= len(self.replicas):
            self._last_called = 0
        return self._last_called
[docs] def to( self, system: Union[str, Cluster], env: Optional[Union[str, List[str], Env]] = None, name: Optional[str] = None, force_install: bool = False, ): """Put a copy of the Mapper and its internal module on the destination system and env, and return the new mapper. Example: >>> local_mapper = rh.mapper(my_module, replicas=2) >>> cluster_mapper = local_mapper.to(my_cluster) """ if not self.module.system: # Note that we don't pass name here, as this is the name meant for the mapper self.module = self.module.to( system=system, env=env, force_install=force_install ) remote_mapper = super().to( system=system, env=env, name=name, force_install=force_install ) if isinstance(self._num_auto_replicas, int): remote_mapper.add_replicas(self._num_auto_replicas) return remote_mapper
[docs] def map(self, *args, method: Optional[str] = None, retries: int = 0, **kwargs): """Map the function or method over a list of arguments. Example: >>> mapper = rh.mapper(local_sum, replicas=2).to(my_cluster) >>> mapper.map([1, 2], [1, 4], [2, 3], retries=3) >>> # If you're mapping over a remote module, you can choose not to specify which method to call initially >>> # so you can call different methods in different maps (note that our replicas can hold state!) >>> # Note that in the example below we're careful to use the same number of replicas as data >>> # we have to process, or the state in a replica would be overwritten by the next call. >>> shards = len(source_paths) >>> mapper = rh.mapper(remote_module, replicas=shards).to(my_cluster) >>> mapper.map(*source_paths, method="load_data") >>> mapper.map([]*shards, method="process_data") # Calls each replica once with no args >>> mapper.map(*output_paths, method="save_data") """ # Don't stream logs by default unless the mapper is remote (i.e. mediating the mapping) return self.starmap( arg_list=zip(*args), method=method, retries=retries, **kwargs )
[docs] def starmap( self, arg_list: List, method: Optional[str] = None, retries: int = 0, **kwargs ): """Like :func:`map` except that the elements of the iterable are expected to be iterables that are unpacked as arguments. An iterable of ``[(1,2), (3, 4)]`` results in ``func(1,2), func(3,4)]``. Example: >>> def local_sum(arg1, arg2, arg3): >>> return arg1 + arg2 + arg3 >>> >>> remote_fn = rh.function(local_sum).to(my_cluster) >>> mapper = rh.mapper(remote_fn, replicas=2) >>> arg_list = [(1,2), (3, 4)] >>> # runs the function twice, once with args (1, 2) and once with args (3, 4) >>> mapper.starmap(arg_list) """ # Don't stream logs by default unless the mapper is remote (i.e. mediating the mapping) if self.system and not self.system.on_this_cluster(): kwargs["stream_logs"] = kwargs.get("stream_logs", True) else: kwargs["stream_logs"] = kwargs.get("stream_logs", False) retry_list = [] def call_method_on_replica(job, retry=True): replica, method_name, context, argies, kwargies = job # reset context for var, value in context.items(): var.set(value) try: return getattr(replica, method_name)(*argies, **kwargies) except Exception as e: logger.error(f"Error running {method_name} on {replica.name}: {e}") if retry: retry_list.append(job) else: return e context = contextvars.copy_context() jobs = [ ( self.replicas[self.increment_counter()], method or self.method, context, args, kwargs, ) for args in arg_list ] results = [] max_threads = round(self.concurrency * self.num_replicas) with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor: futs = [ executor.submit(call_method_on_replica, job, retries > 0) for job in jobs ] for fut in tqdm(concurrent.futures.as_completed(futs), total=len(jobs)): results.extend([fut.result()]) for i in range(retries): if len(retry_list) == 0: break logger.info(f"Retry {i}: {len(retry_list)} failed jobs") jobs, retry_list = retry_list, [] retry = i != retries - 1 results.append( list( tqdm( 
executor.map(call_method_on_replica, jobs, retry), total=len(jobs), ) ) ) return results
# TODO should we add an async version of this for when we're on the cluster? # async def call_method_on_args(argies): # return getattr(self.replicas[self.increment_counter()], self.method)(*argies, **kwargs) # # async def gather(): # return await asyncio.gather( # *[ # call_method_on_args(args) # for args in zip(*args) # ] # ) # return asyncio.run(gather())
[docs] def call(self, *args, method: Optional[str] = None, **kwargs): """Call the function or method on a single replica. Example: >>> def local_sum(arg1, arg2, arg3): >>> return arg1 + arg2 + arg3 >>> >>> remote_fn = rh.function(local_sum).to(my_cluster) >>> mapper = rh.mapper(remote_fn, replicas=2) >>> for i in range(10): >>> mapper.call(i, 1, 2) >>> # output: 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, run in round-robin replica order """ return getattr(self.replicas[self.increment_counter()], method or self.method)( *args, **kwargs )
def mapper(
    module: Union[Module, Callable],
    method: Optional[str] = None,
    replicas: Union[None, int, List[Module]] = None,
    concurrency: int = 1,
    **kwargs,
) -> Mapper:
    """
    A factory method for creating Mapper modules. A mapper is a module that can map a function
    or module method over a list of inputs in various ways.

    Args:
        module (Module): The module or function to be mapped.
        method (Optional[str], optional): The method of the module to be called. If the module is already a
            callable, this value defaults to ``"call"``.
        concurrency (int, optional): The number of concurrent calls to each replica, executed in separate threads.
            Defaults to 1.
        replicas (Optional[List[Module]], optional): List of user-specified replicas, or an int specifying the number
            of replicas to be automatically created. Defaults to None.

    Returns:
        Mapper: The resulting Mapper object.

    Example:
        >>> def local_sum(arg1, arg2, arg3):
        >>>     return arg1 + arg2 + arg3
        >>>
        >>> # Option 1: Pass a function directly to the mapper, and send both to the cluster
        >>> mapper = rh.mapper(local_sum, replicas=2).to(my_cluster)
        >>> mapper.map([1, 2], [1, 4], [2, 3])

        >>> # Option 2: Create a remote module yourself and pass it to the mapper, which is still local
        >>> remote_fn = rh.function(local_sum).to(my_cluster, env=my_fn_env)
        >>> mapper = rh.mapper(remote_fn, replicas=2)
        >>> mapper.map([1, 2], [1, 4], [2, 3])
        >>> # output: [4, 9]

        >>> # Option 3: Create a remote module and mapper for greater flexibility, and send both to the cluster
        >>> # You can map over a "class" module (stateless) or an "instance" module to preserve state
        >>> remote_class = rh.module(cls=MyClass).to(system=cluster, env=my_module_env)
        >>> stateless_mapper = rh.mapper(remote_class, method="my_class_method", replicas=2).to(cluster)
        >>> mapper.map([1, 2], [1, 4], [2, 3])
        >>> remote_app = remote_class()
        >>> stateful_mapper = rh.mapper(remote_app, method="my_instance_method", replicas=2).to(cluster)
        >>> mapper.map([1, 2], [1, 4], [2, 3])
    """
    # Plain callables are first wrapped into a Runhouse Function module.
    if callable(module) and not isinstance(module, Module):
        module = function(module, **kwargs)

    # Function modules are invoked through their ``call`` method unless a
    # method name was given explicitly.
    if isinstance(module, Function):
        method = method or "call"

    return Mapper(
        module=module,
        method=method,
        replicas=replicas,
        concurrency=concurrency,
        **kwargs,
    )