+
Skip to content

[MNT] - Consolidate duplicated IO functionality #350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,21 @@ Plot related utilies for styling and managing plots.
recursive_plot
save_figure

Input / Output (IO)
-------------------

Save & load related functionality.

.. currentmodule:: specparam.io.models

.. autosummary::
   :toctree: generated/

   load_model
   load_group
   load_time
   load_event

Utilities
---------

Expand Down Expand Up @@ -407,19 +422,6 @@ Utilities for working with parameters

compute_knee_frequency

Input / Output (IO)
~~~~~~~~~~~~~~~~~~~

.. currentmodule:: specparam.utils.io

.. autosummary::
   :toctree: generated/

   load_model
   load_group_model
   load_time_model
   load_event_model

Methods Reports
~~~~~~~~~~~~~~~

Expand Down
63 changes: 26 additions & 37 deletions specparam/io/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
"""File I/O for model objects."""
"""File I/O for model objects.

Notes
-----
Model loader functions import model objects locally to prevent circular imports.
"""

import io

Expand Down Expand Up @@ -139,72 +144,62 @@ def save_event(event, file_name, file_path=None, append=False,
save_settings=save_settings, save_data=save_data)


def load_model(file_name, file_path=None, regenerate=True, model=None):
"""Load a SpectralModel object.
def load_model(file_name, file_path=None, regenerate=True):
"""Load a SpectralModel object from file.

Parameters
----------
file_name : str
File(s) to load data from.
file_path : str, optional
File to load the data from.
file_path : Path or str, optional
Path to directory to load from. If None, loads from current directory.
regenerate : bool, optional, default: True
Whether to regenerate the model fit from the loaded data, if data is available.
model : SpectralModel
Loaded model object with data from file.

Returns
-------
model : SpectralModel
Loaded model object with data from file.
"""

# Check for model object, import (avoid circular) and initialize if not
if not model:
from specparam.objs import SpectralModel
model = SpectralModel()

from specparam.objs import SpectralModel
model = SpectralModel()
model.load(file_name, file_path, regenerate)

return model


def load_group(file_name, file_path=None, group=None):
"""Load a SpectralGroupModel object.
def load_group(file_name, file_path=None):
"""Load a SpectralGroupModel object from file.

Parameters
----------
file_name : str
File(s) to load data from.
file_path : str, optional
file_path : Path or str, optional
Path to directory to load from. If None, loads from current directory.
group : SpectralGroupModel
Loaded model object with data from file.

Returns
-------
group : SpectralGroupModel
Loaded model object with data from file.
"""

# Check for model object, import (avoid circular) and initialize if not
if not group:
from specparam.objs import SpectralGroupModel
group = SpectralGroupModel()

from specparam.objs import SpectralGroupModel
group = SpectralGroupModel()
group.load(file_name, file_path)

return group


def load_time(file_name, file_path=None, peak_org=None, time=None):
"""Load a SpectralTimeModel object.
def load_time(file_name, file_path=None, peak_org=None):
"""Load a SpectralTimeModel object from file.

Parameters
----------
file_name : str
File(s) to load data from.
file_path : str, optional
file_path : Path or str, optional
Path to directory to load from. If None, loads from current directory.
peak_org : int or Bands, optional
How to organize peaks.
Expand All @@ -217,24 +212,21 @@ def load_time(file_name, file_path=None, peak_org=None, time=None):
Loaded model object with data from file.
"""

# Check for model object, import (avoid circular) and initialize if not
if not time:
from specparam.objs import SpectralTimeModel
time = SpectralTimeModel()

from specparam.objs import SpectralTimeModel
time = SpectralTimeModel()
time.load(file_name, file_path, peak_org)

return time


def load_event(file_name, file_path=None, peak_org=None, event=None):
"""Load a SpectralTimeEventModel object.
"""Load a SpectralTimeEventModel object from file.

Parameters
----------
file_name : str
File(s) to load data from.
file_path : str, optional
file_path : Path or str, optional
Path to directory to load from. If None, loads from current directory.
peak_org : int or Bands, optional
How to organize peaks.
Expand All @@ -247,11 +239,8 @@ def load_event(file_name, file_path=None, peak_org=None, event=None):
Loaded model object with data from file.
"""

# Check for model object, import (avoid circular) and initialize if not
if not event:
from specparam.objs import SpectralTimeEventModel
event = SpectralTimeEventModel()

from specparam.objs import SpectralTimeEventModel
event = SpectralTimeEventModel()
event.load(file_name, file_path, peak_org)

return event
Expand Down
86 changes: 62 additions & 24 deletions specparam/tests/io/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@

import os

import numpy as np

from specparam.core.items import OBJ_DESC
from specparam.io.files import load_json
from specparam.objs import (SpectralModel, SpectralGroupModel,
SpectralTimeModel, SpectralTimeEventModel)

from specparam.tests.tsettings import TEST_DATA_PATH

Expand Down Expand Up @@ -132,49 +136,83 @@ def test_save_event(tfe):
for ind in range(len(tfe)):
assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + '.json'))

def test_load_file_contents():
"""Check that loaded model files contain the contents they should."""

# Loads file saved from `test_save_model_str`
file_name = 'test_model_all'

loaded_data = load_json(file_name, TEST_DATA_PATH)

for setting in OBJ_DESC['settings']:
assert setting in loaded_data.keys()
for result in OBJ_DESC['results']:
assert result in loaded_data.keys()
for datum in OBJ_DESC['data']:
assert datum in loaded_data.keys()

def test_load_model():

# Loads file saved from `test_save_model_str`
file_name = 'test_model_all'

tmodel = load_model(file_name, TEST_DATA_PATH)
assert tmodel
tfm = load_model(file_name, TEST_DATA_PATH)

assert isinstance(tfm, SpectralModel)

# Check that all elements get loaded
for result in OBJ_DESC['results']:
assert not np.all(np.isnan(getattr(tfm, result)))
for setting in OBJ_DESC['settings']:
assert getattr(tfm, setting) is not None
for data in OBJ_DESC['data']:
assert getattr(tfm, data) is not None
for meta_dat in OBJ_DESC['meta_data']:
assert getattr(tfm, meta_dat) is not None

def test_load_group():

# Loads file saved from `test_save_group`
file_name = 'test_group_all'

tgroup = load_group(file_name, TEST_DATA_PATH)
assert tgroup
tfg = load_group(file_name, TEST_DATA_PATH)

assert isinstance(tfg, SpectralGroupModel)

def test_load_time():
# Check that all elements get loaded
assert len(tfg.group_results) > 0
for setting in OBJ_DESC['settings']:
assert getattr(tfg, setting) is not None
assert tfg.power_spectra is not None
for meta_dat in OBJ_DESC['meta_data']:
assert getattr(tfg, meta_dat) is not None

def test_load_time(tbands):

# Loads file saved from `test_save_time`
file_name = 'test_time_all'

ttime = load_time(file_name, TEST_DATA_PATH)
assert ttime
# Load without bands definition
tft = load_time(file_name, TEST_DATA_PATH)
assert isinstance(tft, SpectralTimeModel)

# Load with bands definition
tft2 = load_time(file_name, TEST_DATA_PATH, tbands)
assert isinstance(tft2, SpectralTimeModel)
assert tft2.time_results

def test_load_event():
def test_load_event(tbands):

# Loads file saved from `test_save_event`
file_name = 'test_event_all'

tevent = load_event(file_name, TEST_DATA_PATH)
assert tevent
# Load without bands definition
tfe = load_event(file_name, TEST_DATA_PATH)
assert isinstance(tfe, SpectralTimeEventModel)
assert len(tfe) > 1

def test_load_file_contents():
"""Check that loaded model files contain the contents they should."""

# Loads file saved from `test_save_model_str`
file_name = 'test_model_all'

loaded_data = load_json(file_name, TEST_DATA_PATH)

for setting in OBJ_DESC['settings']:
assert setting in loaded_data.keys()
for result in OBJ_DESC['results']:
assert result in loaded_data.keys()
for datum in OBJ_DESC['data']:
assert datum in loaded_data.keys()
# Load with bands definition
tfe2 = load_event(file_name, TEST_DATA_PATH, tbands)
assert isinstance(tfe2, SpectralTimeEventModel)
assert tfe2.event_time_results
assert len(tfe2) > 1
75 changes: 0 additions & 75 deletions specparam/tests/utils/test_io.py

This file was deleted.

Loading
点击 这是indexloc提供的php浏览器服务,不要输入任何密码和下载