import contextlib
import logging
import subprocess
import time
import warnings
from pathlib import Path
from typing import Any, Dict
import rich.errors
import yaml
try:
import sky
from sky.backends import backend_utils
except ImportError:
pass
from runhouse.constants import (
DEFAULT_HTTP_PORT,
DEFAULT_HTTPS_PORT,
DEFAULT_SERVER_PORT,
LOCAL_HOSTS,
)
from runhouse.globals import configs, rns_client
from runhouse.resources.hardware.utils import ServerConnectionType
from .cluster import Cluster
logger = logging.getLogger(__name__)
[docs]class OnDemandCluster(Cluster):
RESOURCE_TYPE = "cluster"
RECONNECT_TIMEOUT = 5
DEFAULT_KEYFILE = "~/.ssh/sky-key"
[docs] def __init__(
self,
name,
instance_type: str = None,
num_instances: int = None,
provider: str = None,
default_env: "Env" = None,
dryrun=False,
autostop_mins=None,
use_spot=False,
image_id=None,
memory=None,
disk_size=None,
open_ports=None,
server_host: str = None,
server_port: int = None,
server_connection_type: str = None,
ssl_keyfile: str = None,
ssl_certfile: str = None,
domain: str = None,
den_auth: bool = False,
region=None,
**kwargs, # We have this here to ignore extra arguments when calling from from_config
):
"""
On-demand `SkyPilot <https://github.com/skypilot-org/skypilot/>`_ Cluster.
.. note::
To build a cluster, please use the factory method :func:`cluster`.
"""
super().__init__(
name=name,
default_env=default_env,
server_host=server_host,
server_port=server_port,
server_connection_type=server_connection_type,
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile,
domain=domain,
den_auth=den_auth,
dryrun=dryrun,
**kwargs,
)
self.instance_type = instance_type
self.num_instances = num_instances
self.provider = provider or configs.get("default_provider")
self._autostop_mins = (
autostop_mins
if autostop_mins is not None
else configs.get("default_autostop")
)
self.open_ports = open_ports
self.use_spot = use_spot if use_spot is not None else configs.get("use_spot")
self.image_id = image_id
self.region = region
self.memory = memory
self.disk_size = disk_size
self.stable_internal_external_ips = kwargs.get(
"stable_internal_external_ips", None
)
# Checks if state info is in local sky db, populates if so.
if not dryrun and not self.ips and not self.creds_values:
# Cluster status is set to INIT in the Sky DB right after starting, so we need to refresh once
self._update_from_sky_status(dryrun=False)
@property
def autostop_mins(self):
return self._autostop_mins
@autostop_mins.setter
def autostop_mins(self, mins):
self.check_server()
if self.on_this_cluster():
raise ValueError("Cannot set autostop_mins live on the cluster.")
else:
if self.run_python(["import skypilot"])[0] != 0:
raise ImportError(
"Skypilot must be installed on the cluster in order to set autostop."
)
self.client.set_settings({"autostop_mins": mins})
sky.autostop(self.name, mins, down=True)
self._autostop_mins = mins
def config(self, condensed=True):
config = super().config(condensed)
self.save_attrs_to_config(
config,
[
"instance_type",
"num_instances",
"provider",
"open_ports",
"use_spot",
"image_id",
"region",
"stable_internal_external_ips",
],
)
config["autostop_mins"] = self._autostop_mins
return config
[docs] def endpoint(self, external=False):
try:
self.check_server()
except ValueError:
return None
return super().endpoint(external)
def _copy_sky_yaml_from_cluster(self, abs_yaml_path: str):
if not Path(abs_yaml_path).exists():
Path(abs_yaml_path).parent.mkdir(parents=True, exist_ok=True)
self._rsync("~/.sky/sky_ray.yml", abs_yaml_path, up=False)
# Save SSH info to the ~/.ssh/config
ray_yaml = yaml.safe_load(open(abs_yaml_path, "r"))
backend_utils.SSHConfigHelper.add_cluster(
self.name, [self.address], ray_yaml["auth"]
)
@staticmethod
def relative_yaml_path(yaml_path):
if Path(yaml_path).is_absolute():
yaml_path = "~/.sky/generated/" + Path(yaml_path).name
return yaml_path
def set_connection_defaults(self):
if self.server_connection_type in [
ServerConnectionType.AWS_SSM,
]:
raise ValueError(
f"OnDemandCluster does not support server connection type {self.server_connection_type}"
)
if not self.server_connection_type:
if self.ssl_keyfile or self.ssl_certfile:
self.server_connection_type = ServerConnectionType.TLS
else:
self.server_connection_type = ServerConnectionType.SSH
if self.server_port is None:
if self.server_connection_type == ServerConnectionType.TLS:
self.server_port = DEFAULT_HTTPS_PORT
elif self.server_connection_type == ServerConnectionType.NONE:
self.server_port = DEFAULT_HTTP_PORT
else:
self.server_port = DEFAULT_SERVER_PORT
if (
self.server_connection_type
in [ServerConnectionType.TLS, ServerConnectionType.NONE]
and self.server_host in LOCAL_HOSTS
):
warnings.warn(
f"Server connection type: {self.server_connection_type}, server host: {self.server_host}. "
f"Note that this will require opening an SSH tunnel to forward traffic from"
f" {self.server_host} to the server."
)
self.open_ports = (
[]
if self.open_ports is None
else [self.open_ports]
if isinstance(self.open_ports, (int, str))
else self.open_ports
)
if self.open_ports:
self.open_ports = [str(p) for p in self.open_ports]
if str(self.server_port) in self.open_ports:
if (
self.server_connection_type
in [ServerConnectionType.TLS, ServerConnectionType.NONE]
and not self.den_auth
):
warnings.warn(
"Server is insecure and must be inside a VPC or have `den_auth` enabled to secure it."
)
else:
warnings.warn(
f"Server port {self.server_port} not included in open ports. Note you are responsible for opening "
f"the port or ensure you have access to it via a VPC."
)
else:
# If using HTTP or HTTPS must enable traffic on the relevant port
if self.server_connection_type in [
ServerConnectionType.TLS,
ServerConnectionType.NONE,
]:
if self.server_port:
warnings.warn(
f"No open ports specified. Setting default port {self.server_port} to open."
)
self.open_ports = [str(self.server_port)]
else:
warnings.warn(
f"No open ports specified. Make sure the relevant port is open. "
f"HTTPS default: {DEFAULT_HTTPS_PORT} and HTTP "
f"default: {DEFAULT_HTTP_PORT}."
)
# ----------------- Launch/Lifecycle Methods -----------------
[docs] def is_up(self) -> bool:
"""Whether the cluster is up.
Example:
>>> rh.ondemand_cluster("rh-cpu").is_up()
"""
if self.on_this_cluster():
return True
self._update_from_sky_status(dryrun=False)
return self.address is not None
def _sky_status(self, refresh: bool = True, retry: bool = True):
"""
Get status of Sky cluster.
Return dict looks like:
.. code-block::
{'name': 'sky-cpunode-donny',
'launched_at': 1662317201,
'handle': ResourceHandle(
cluster_name=sky-cpunode-donny,
head_ip=54.211.97.164,
cluster_yaml=/Users/donny/.sky/generated/sky-cpunode-donny.yml,
launched_resources=1x AWS(m6i.2xlarge),
tpu_create_script=None,
tpu_delete_script=None),
'last_use': 'sky cpunode',
'status': <ClusterStatus.UP: 'UP'>,
'autostop': -1,
'metadata': {}}
.. note:: For more information see SkyPilot's :code:`ResourceHandle` `class
<https://github.com/skypilot-org/skypilot/blob/0c2b291b03abe486b521b40a3069195e56b62324/sky/backends/cloud_vm_ray_backend.py#L1457>`_.
"""
if not sky.global_user_state.get_cluster_from_name(self.name):
return None
try:
state = sky.status(cluster_names=[self.name], refresh=refresh)
except rich.errors.LiveError as e:
# We can't have more than one Live display at once, so if we've already launched one (e.g. the first
# time we call status), we can retry without refreshing
if not retry:
raise e
return self._sky_status(refresh=False, retry=False)
# We still need to check if the cluster present in case the cluster went down and was removed from the DB
if len(state) == 0:
return None
return state[0]
def _start_ray_workers(self, ray_port, env):
# Find the internal IP corresponding to the public_head_ip and the rest are workers
internal_head_ip = None
worker_ips = []
stable_internal_external_ips = self._sky_status()[
"handle"
].stable_internal_external_ips
for internal, external in stable_internal_external_ips:
if external == self.address:
internal_head_ip = internal
else:
# NOTE: Using external worker address here because we're running from local
worker_ips.append(external)
logger.debug(f"Internal head IP: {internal_head_ip}")
for host in worker_ips:
logger.info(
f"Starting Ray on worker {host} with head node at {internal_head_ip}:{ray_port}."
)
self.run(
commands=[
f"ray start --address={internal_head_ip}:{ray_port} --disable-usage-stats",
],
node=host,
env=env,
)
time.sleep(5)
def _populate_connection_from_status_dict(self, cluster_dict: Dict[str, Any]):
if cluster_dict and cluster_dict["status"].name in ["UP", "INIT"]:
handle = cluster_dict["handle"]
self.address = handle.head_ip
self.stable_internal_external_ips = handle.stable_internal_external_ips
yaml_path = handle.cluster_yaml
if Path(yaml_path).exists():
ssh_values = backend_utils.ssh_credential_from_yaml(yaml_path)
if not self.creds_values:
from runhouse.resources.secrets.utils import setup_cluster_creds
self._creds = setup_cluster_creds(ssh_values, self.name)
# Add worker IPs if multi-node cluster - keep the head node as the first IP
self.ips = [ext for _, ext in self.stable_internal_external_ips]
else:
self.address = None
self._creds = None
self.stable_internal_external_ips = None
def _update_from_sky_status(self, dryrun: bool = False):
# Try to get the cluster status from SkyDB
if self.is_shared:
# If the cluster is shared can ignore, since the sky data will only be saved on the machine where
# the cluster was initially upped
return
cluster_dict = self._sky_status(refresh=not dryrun)
self._populate_connection_from_status_dict(cluster_dict)
def get_instance_type(self):
if self.instance_type and "--" in self.instance_type: # K8s specific syntax
return self.instance_type
elif (
self.instance_type
and ":" not in self.instance_type
and "CPU" not in self.instance_type
):
return self.instance_type
return None
def accelerators(self):
if (
self.instance_type
and ":" in self.instance_type
and "CPU" not in self.instance_type
):
return self.instance_type
return None
def num_cpus(self):
if (
self.instance_type
and ":" in self.instance_type
and "CPU" in self.instance_type
):
return self.instance_type.rsplit(":", 1)[1]
return None
[docs] def up(self):
"""Up the cluster.
Example:
>>> rh.ondemand_cluster("rh-cpu").up()
"""
if self.on_this_cluster():
return self
supported_providers = ["cheapest"] + list(sky.clouds.CLOUD_REGISTRY)
if self.provider not in supported_providers:
raise ValueError(
f"Cluster provider {self.provider} not supported. Must be one {supported_providers} supported by SkyPilot."
)
task = sky.Task(num_nodes=self.num_instances)
cloud_provider = (
sky.clouds.CLOUD_REGISTRY.from_str(self.provider)
if self.provider != "cheapest"
else None
)
task.set_resources(
sky.Resources(
# TODO: confirm if passing instance type in old way (without --) works when provider is k8s
cloud=cloud_provider,
instance_type=self.get_instance_type(),
accelerators=self.accelerators(),
cpus=self.num_cpus(),
memory=self.memory,
region=self.region or configs.get("default_region"),
disk_size=self.disk_size,
ports=self.open_ports,
image_id=self.image_id,
use_spot=self.use_spot,
)
)
if Path("~/.rh/config.yaml").expanduser().exists():
task.set_file_mounts(
{
"~/.rh/config.yaml": "~/.rh/config.yaml",
}
)
sky.launch(
task,
cluster_name=self.name,
idle_minutes_to_autostop=self._autostop_mins,
down=True,
)
self._update_from_sky_status()
self.restart_server()
return self
[docs] def keep_warm(self, autostop_mins: int = -1):
"""Keep the cluster warm for given number of minutes after inactivity.
Args:
autostop_mins (int): Amount of time (in min) to keep the cluster warm after inactivity.
If set to -1, keep cluster warm indefinitely. (Default: `-1`)
"""
self.autostop_mins = autostop_mins
return self
[docs] def teardown(self):
"""Teardown cluster.
Example:
>>> rh.ondemand_cluster("rh-cpu").teardown()
"""
# Stream logs
sky.down(self.name)
self.address = None
[docs] def teardown_and_delete(self):
"""Teardown cluster and delete it from configs.
Example:
>>> rh.ondemand_cluster("rh-cpu").teardown_and_delete()
"""
self.teardown()
rns_client.delete_configs(resource=self)
[docs] @contextlib.contextmanager
def pause_autostop(self):
"""Context manager to temporarily pause autostop.
Example:
>>> with rh.ondemand_cluster.pause_autostop():
>>> rh.ondemand_cluster.run(["python train.py"])
"""
sky.autostop(self.name, idle_minutes=-1)
yield
sky.autostop(self.name, idle_minutes=self._autostop_mins, down=True)
# ----------------- SSH Methods ----------------- #
[docs] @staticmethod
def cluster_ssh_key(path_to_file):
"""Retrieve SSH key for the cluster.
Example:
>>> ssh_priv_key = rh.ondemand_cluster("rh-cpu").cluster_ssh_key("~/.ssh/id_rsa")
"""
try:
f = open(path_to_file, "r")
private_key = f.read()
return private_key
except FileNotFoundError:
raise Exception(f"File with ssh key not found in: {path_to_file}")
[docs] def ssh(self, node: str = None):
"""SSH into the cluster. If no node is specified, will SSH onto the head node.
Example:
>>> rh.ondemand_cluster("rh-cpu").ssh()
>>> rh.ondemand_cluster("rh-cpu", node="3.89.174.234").ssh()
"""
if self.provider == "kubernetes":
command = f"kubectl get pods | grep {self.name}"
try:
output = subprocess.check_output(command, shell=True, text=True)
lines = output.strip().split("\n")
if lines:
pod_name = lines[0].split()[0]
else:
logger.info("No matching pods found.")
except subprocess.CalledProcessError as e:
raise Exception(f"Error: {e}")
cmd = f"kubectl exec -it {pod_name} -- /bin/bash"
subprocess.run(cmd, shell=True, check=True)
else:
# If SSHing onto a specific node, which requires the default sky public key for verification
from runhouse.resources.hardware.sky_ssh_runner import SkySSHRunner, SshMode
ssh_user = self.creds_values.get("ssh_user")
sky_key = Path(
self.creds_values.get("ssh_private_key", self.DEFAULT_KEYFILE)
).expanduser()
if not sky_key.exists():
raise FileNotFoundError(f"Expected default sky key in path: {sky_key}")
runner = SkySSHRunner(
ip=node or self.address,
ssh_user=ssh_user,
port=self.ssh_port,
ssh_private_key=sky_key,
)
subprocess.run(
runner._ssh_base_command(
ssh_mode=SshMode.INTERACTIVE, port_forward=None
)
)