# Robotics
# code
# Silk_LLM / utilities.py
# nlsefouh's picture
# Upload 8 files
# aac5fad verified
"""Input and output utilities.
The central concept in the I/O infrastructure is a "frame" - a dataclass that represents one atomic
configuration that contains data of different kinds. Atomic units are used in the frame object itself,
unless explicitly stated otherwise. Units specified by the file format are used in the files themselves.
"""
# Public API of this module: `from <this module> import *` exports exactly these names.
# NOTE(review): `register_io` is not listed here, presumably because format modules import it
# explicitly - confirm this is intended.
__all__ = [
    'AnyPath',
    'get_fn_test',
    'Frame',
    'open_safe',
    'working_directory',
    'temporary_directory',
    'to_file',
    'from_file',
    'read_frames',
    'write_frames',
    'merge_frames',
]
import os
import shutil
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Sequence, Union
import numpy as np
from ..utilities import AMLIOError
# anything that can be used as a path on the file system
AnyPath = Union[str, Path]

# file name extension (without the leading dot) -> name of the file format it implies
ext2fmt: dict = {}

# file format name -> registered entries: the 'read' and/or 'write' function for single frames,
# and optionally the 'extension' associated with the format
formats: defaultdict = defaultdict(dict)
def get_fn_test(filename):
    """Return the absolute, resolved path of a file in the test data directory.

    The test data directory is `../tests/data` relative to the parent of the directory
    containing this module.

    Arguments:
        filename: name of a file in the test data directory, without any path
    """
    package_parent = Path(__file__).parent.parent
    data_dir = package_parent / '..' / 'tests' / 'data'
    return (data_dir / filename).resolve()
def register_io(fformat: str, operation: str, extension: Union[str, None] = None):
    """Decorator to register an I/O operation for a specific file format.

    Optionally, the function can also register a file name extension for automatic
    detection of the file format from the file name.

    Arguments:
        fformat: name of file format
        operation: I/O operation - "read" or "write"
        extension: file name extension (without the leading dot) or `None`

    Returns:
        a decorator that registers the decorated function and returns it unchanged

    Raises:
        ValueError: if `operation` is not recognized, or if `extension` is already
            registered for a different file format
    """
    # validate eagerly, so that a wrong operation fails at the decorator call site
    if operation not in ('read', 'write'):
        raise ValueError('Unrecognized operation. Allowed values: "read", "write".')

    def decorator(function):
        # Check for conflicts before touching any registry, so that a failed registration
        # does not leave a half-registered format behind.
        if (extension is not None) and (extension in ext2fmt) and (ext2fmt[extension] != fformat):
            raise ValueError(f'Attempted to register the same file extension ({extension}) twice.')

        formats[fformat][operation] = function
        if extension is not None:
            formats[fformat]['extension'] = extension
            ext2fmt[extension] = fformat

        # Return the function itself; otherwise the decorated name would be rebound to `None`.
        return function

    return decorator
@dataclass(eq=False)
class Frame:
    """All possible data of a single frame.

    Used to exchange data between data structures and I/O routines. Every field defaults to
    `None`, which means that the given kind of data is not set/available. We do not provide a
    comparison operator, at least for now, as comparing NumPy arrays is more involved.
    """

    # slots do not work correctly with dataclass
    # Here is an alternative: https://pypi.org/project/dataslots/
    # Here is some context: https://github.com/ericvsmith/dataclasses/issues/28
    # __slots__ = ['names', 'positions', 'cell', 'comment', 'energy', 'forces']

    # atom names
    names: Optional[Sequence] = None
    positions: Optional[np.ndarray] = None
    cell: Optional[np.ndarray] = None
    comment: Optional[str] = None
    energy: Optional[float] = None
    forces: Optional[np.ndarray] = None

    def update(self, other: 'Frame', force: bool = False) -> None:
        """Update this frame in place with data from another.

        Atom names are never copied. If `other` has names, they are only used to check that both
        frames describe the same atoms.

        Arguments:
            other: another frame
            force: whether to overwrite data that is already set in this frame

        Raises:
            ValueError: if `other` has atom names and this frame has none or different ones
        """
        if other.names is not None:
            # Compare as lists: lists, tuples and NumPy arrays with the same content are
            # considered equal, and we avoid the ambiguous truth value of an element-wise
            # NumPy `!=`.
            if (self.names is None) or (list(self.names) != list(other.names)):
                raise ValueError('Inconsistent atom names.')

        # take over all that we can
        for attr in ('positions', 'cell', 'comment', 'energy', 'forces'):
            value = getattr(other, attr)
            if value is None:
                continue
            if force or (getattr(self, attr) is None):
                setattr(self, attr, value)
def open_safe(filename, mode='r', buffering=-1, verbose=False, encoding=None):
    """A wrapper around `open` which saves backup files.

    If opening for writing and `filename` exists, it is first renamed to `#<name>#<i>#` in the
    same directory, where `i` is the smallest non-negative integer for which that name does not
    exist yet. This way we do not overwrite any data.

    Arguments:
        filename: name of file to open
        mode: file open mode; the mode characters may come in any order, as with `open`
        buffering: passed through to `open`
        verbose: whether to print to standard output what backup was performed
        encoding: passed through to `open` (text mode only)

    Returns:
        an open file

    Raises:
        NotImplementedError: if `mode` contains none of "r", "w", "a", "x"
    """
    # `open` accepts mode characters in any order (e.g. 'bw'), so inspect all of them
    # rather than only the first one.
    kinds = set('rwax').intersection(mode)

    if kinds == {'w'}:
        # if writing, make sure the file is not overwritten
        filename = Path(filename)
        i = 0
        fn_backup = filename
        while fn_backup.exists():
            fn_backup = filename.with_name(f'#{filename.name:s}#{i:d}#')
            i += 1
        if fn_backup != filename:
            filename.rename(fn_backup)
            if verbose:
                print(f'Backup performed: {filename} -> {fn_backup}\n')
    elif not kinds:
        # did not expect that, more work needed
        raise NotImplementedError(f'Unsupported file open mode: {mode:s}.')
    # Read, append and exclusive creation cannot overwrite data. Contradictory modes such as
    # 'rw' are passed on unchanged so that `open` rejects them, before any backup is made.

    return open(filename, mode, buffering, encoding=encoding)
@contextmanager
def working_directory(directory):
    """Temporarily switch the current working directory.

    The previous working directory is restored on exit, also when the body raises.
    The standard library does not offer this [1], but it is handy, especially in tests.
    The old pytest fixture (`tmpdir`) used py.path [2], which has `as_cwd`, but that is legacy
    code now and not recommended [3].

    [1] https://bugs.python.org/issue25625
    [2] https://py.readthedocs.io/en/latest/path.html
    [3] https://docs.pytest.org/en/latest/how-to/tmpdir.html

    Arguments:
        directory: directory to change to
    """
    previous = Path.cwd()
    try:
        os.chdir(directory)
        yield
    finally:
        os.chdir(previous)
@contextmanager
def temporary_directory(directory: AnyPath, parents: bool = False, keep: bool = False):
    """Create a directory that lives for the duration of the context.

    The directory is removed upon exiting the context, unless the user asks to keep it.
    Parent directories created because of `parents=True` are not removed.

    Arguments:
        directory: directory to create
        parents: whether to create parents as well
        keep: whether to keep directory after exiting context

    Yields:
        the directory as a `Path`, in the same (relative or absolute) form as given

    Raises:
        AMLIOError: if `directory` already exists
    """
    directory = Path(directory)

    # catch conflict early, a nicer error message
    if directory.exists():
        raise AMLIOError(f'Unable to create directory, already exists: {directory.absolute()}')

    # actually make the directory
    directory.mkdir(parents=parents)

    # Remember the absolute location, so that cleanup still removes the right directory
    # even if the working directory is changed inside the context.
    directory_abs = directory.absolute()

    # create context, clean up if needed
    try:
        yield directory
    finally:
        # The directory may already have been removed inside the context. In that case there is
        # nothing to clean up, and `rmtree` must not raise here and mask another exception.
        if not keep and directory_abs.exists():
            shutil.rmtree(directory_abs)
def from_file(fn_in, binary=False):
    """Return the whole content of a file.

    The file is read in text mode by default, giving a string.
    If `binary` is true, it is read in binary mode instead, giving bytes.
    """
    mode = 'rb' if binary else 'r'
    with open(fn_in, mode) as handle:
        return handle.read()
def to_file(data, fn_out, binary=False, verbose=False):
    """Write `data` into a file.

    `data` would typically be a string, or bytes if `binary` is true. The file is opened via
    `open_safe`, so an existing file of the same name is backed up instead of being overwritten.
    If `verbose` is true, the backup is reported.
    """
    mode = 'wb' if binary else 'w'
    with open_safe(fn_out, mode, verbose=verbose) as handle:
        handle.write(data)
def get_io_operation(fn, fformat, operation):
    """Select I/O function for given file format.

    Arguments:
        fn: name of file to operate on; used to detect the format from its extension
            when `fformat` is `None`
        fformat: name of file format, or `None`
        operation: I/O operation - "read" or "write"

    Returns:
        function to read or write one frame

    Raises:
        ValueError: if `operation` is not recognized, or the file format does not support it
        KeyError: if `fformat` is `None` and the file name extension is not registered
    """
    if operation not in ('read', 'write'):
        raise ValueError('Unrecognized operation. Allowed values: "read", "write".')

    # automatically pick a file format from the extension (without the leading dot)
    if fformat is None:
        extension = Path(fn).suffix[1:]
        try:
            fformat = ext2fmt[extension]
        except KeyError:
            raise KeyError(f'Extension "{extension:s}" not registered for file format detection.') from None

    # Use `.get` instead of indexing: `formats` is a `defaultdict`, and `formats[fformat]` would
    # silently insert an empty entry into the registry for every unknown format looked up.
    function = formats.get(fformat, {}).get(operation)
    if function is None:
        raise ValueError(f'File format "{fformat}" not supported for operation "{operation:s}".')
    return function
def read_frames(fn_in, fformat=None, **kwargs):
    """Iterate over a trajectory file, yielding all data for each frame.

    This is a generator: the reader is looked up and the file opened only once iteration starts.
    The format-specific reader is called repeatedly on the same open file, with `kwargs`
    forwarded on every call, until it returns `None`.
    """
    read_frame = get_io_operation(fn_in, fformat, 'read')

    # NOTE: file formats read using MDTraj must be opened differently; maybe there is a more
    # elegant way to do that.
    with open(fn_in) as stream:
        frame = read_frame(stream, **kwargs)
        while frame is not None:
            yield frame
            frame = read_frame(stream, **kwargs)
def write_frames(fn_out, frames, fformat=None, **kwargs):
    """Write frames to file.

    The format of the file is given by `fformat`, or inferred from the file extension if
    `fformat` is `None`. The file is opened via `open_safe`, so an existing file of the same
    name is backed up rather than overwritten.

    Arguments:
        fn_out: name of output file
        frames: iterable over `Frame` objects
        fformat: format of the file, or `None`
        kwargs: additional keyword arguments passed on to the format-specific writer for each
            frame, mirroring how `read_frames` forwards keyword arguments to the reader
    """
    write_frame = get_io_operation(fn_out, fformat, 'write')

    # write all frames to file
    with open_safe(fn_out, 'w') as f_out:
        for frame in frames:
            write_frame(f_out, frame, **kwargs)
def merge_frames(frames, *frames_others, force: bool = False):
    """Merge frames from multiple sources.

    Each frame from `frames` is updated in place, by calling its `update` method with the
    corresponding frame from each of the other sources in turn, and then yielded. The length
    of the result is determined by the length of `frames`; the other sources must be at least
    as long as that.

    Arguments:
        frames: iterable over `Frame` objects
        frames_others: more iterables over `Frame` instances
        force: whether to overwrite data

    Yields:
        `Frame` objects from `frames`, updated

    Raises:
        ValueError: if one of the other sources runs out before `frames` does
    """
    # `iter` returns iterators unchanged, but also lets lists, tuples, etc. be used with `next`.
    iterators_others = [iter(frames_other) for frames_other in frames_others]

    for frame in frames:
        for i_source, iterator in enumerate(iterators_others):
            try:
                frame_extra = next(iterator)
            except StopIteration:
                # Inside a generator, a leaking StopIteration would surface as an opaque
                # RuntimeError (PEP 479), so report the actual problem instead.
                raise ValueError(
                    f'Source number {i_source + 1:d} in `frames_others` is shorter than `frames`.'
                ) from None
            frame.update(frame_extra, force=force)
        yield frame