Source code for radiosim.simulations.simulations

import numpy as np
from numpy.random import default_rng
from rich.console import Group
from rich.live import Live
from rich.progress import (
    BarColumn,
    Progress,
    TextColumn,
    TimeElapsedColumn,
)

from radiosim.io import Config
from radiosim.jets import create_jet
from radiosim.mojave import create_mojave
from radiosim.survey import create_survey
from radiosim.utils import (
    add_noise,
    create_grid,
    save_sky_distribution_bundle,
    setup_logger,
)
from radiosim.utils.utils import _save_mojave_bundle

__all__ = ["create_sky_distribution", "simulate_sky_distributions"]

LOGGER = setup_logger(__name__)

OVERALL_PROGRESS = Progress(
    TimeElapsedColumn(),
    TextColumn("{task.description}"),
)

SIMULATION_PROGRESS = Progress(
    TextColumn(
        "[bold blue]Progress for dataset {task.fields[name]}: {task.percentage:.0f}%"
    ),
    BarColumn(),
    TextColumn("({task.completed} of {task.total} steps done)"),
)
PROGRESS_GROUP = Group(
    OVERALL_PROGRESS,
    SIMULATION_PROGRESS,
)


[docs] def simulate_sky_distributions(conf, mode): seed = conf.general.seed if not seed: rng = default_rng() elif isinstance(seed, int): rng = default_rng(seed) bundles = [ conf.dataset.bundles_train, conf.dataset.bundles_valid, conf.dataset.bundles_test, ] # Add task_id for overall progress, i.e. train, valid, # and test dataset tracking for a total of 3. overall_task_id = OVERALL_PROGRESS.add_task("", total=3) with Live(PROGRESS_GROUP): for idx, (opt, num_bundles) in enumerate( zip(["train", "valid", "test"], bundles) ): simulation_task_id = SIMULATION_PROGRESS.add_task( "", total=num_bundles, name=opt ) create_sky_distribution( conf=conf, num_bundles=num_bundles, mode=mode, opt=opt, simulation_task_id=simulation_task_id, rng=rng, ) top_descr = f"[bold #AAAAAA]({idx + 1} out of {3} datasets simulated)" OVERALL_PROGRESS.update(overall_task_id, description=top_descr)
[docs] def create_sky_distribution( conf: Config, num_bundles: int, mode: str, opt: str, simulation_task_id: int, rng=None, ) -> None: for _ in range(num_bundles): path = conf.paths.outpath / f"data_{mode}_{opt}.h5" grid = create_grid(conf.dataset.img_size, conf.dataset.bundle_size) if mode == "jet": sky, target = create_jet(grid, conf) elif mode == "survey": sky, target = create_survey(grid, conf) elif mode == "mojave": data, data_name = create_mojave(conf, rng) sky_bundle = data[0].copy() if conf.dataset.noise: sky_bundle = np.squeeze( add_noise( np.expand_dims(sky_bundle, axis=1), conf.dataset.noise_level, ), axis=1, ) _save_mojave_bundle(path, data=[sky_bundle, *data[1:]], data_name=data_name) SIMULATION_PROGRESS.update(simulation_task_id, advance=1) continue else: raise ValueError( f"Mode {mode} not found. Choose 'survey', " "'jet' or 'mojave' in config file" ) sky_bundle = sky.copy() target_bundle = target.copy() if conf.dataset.noise and conf.dataset.noise_level > 0: sky_bundle = add_noise(sky_bundle, conf.dataset.noise_level) for img in sky_bundle: img -= img.min() img /= img.max() save_sky_distribution_bundle(path, sky_bundle, target_bundle) SIMULATION_PROGRESS.update(simulation_task_id, advance=1)