Core Experiment module

Core experiment module for experiment control.

This module provides the base classes and functionality for running behavioral experiments. It includes:

- State machine implementation for experiment flow control
- Condition management and randomization
- Trial preparation and execution
- Performance tracking and analysis

The module is built around three main classes:

- State: Base class for implementing experiment states
- StateMachine: Control the flow of the experiment
- ExperimentClass: Base class for experiment implementation
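
A typical task script wires these classes together roughly as follows. This is a minimal sketch: the experiment, behavior, stimulus, and logger names are illustrative assumptions, and only the methods documented on this page (setup, make_conditions, push_conditions, start, stop) are taken from the module itself.

    from ethopy.core.logger import Logger            # assumed import path
    # Hypothetical task-specific classes, not part of this module
    from my_task_plugins import MyExperiment, MyBehavior, MyStimulus

    session_params = {
        "max_reward": 3000,
        "min_reward": 30,
        "hydrate_delay": 0,
        "setup_conf_idx": 0,
    }

    exp = MyExperiment()                              # subclass of ExperimentClass
    exp.setup(Logger(), MyBehavior, session_params)
    conditions = exp.make_conditions(
        stim_class=MyStimulus(),
        conditions={"difficulty": 0, "trial_selection": "fixed"},
    )
    exp.push_conditions(conditions)
    exp.start()                                       # runs the state machine
    exp.stop()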

ExperimentClass

Parent Experiment.

Source code in src/ethopy/core/experiment.py
class ExperimentClass:
    """Parent Experiment."""

    curr_trial = 0  # the current trial number in the session
    cur_block = 0  # the current block number in the session
    states = {}  # dictionary with all states of the experiment
    stims = {}  # dictionary with all stimulus classes
    stim = False  # the current condition stimulus class
    sync = False  # a boolean to make synchronization available
    un_choices = []
    blocks = []
    iter = []
    curr_cond = {}
    block_h = []
    has_responded = False
    resp_ready = False
    required_fields = []
    default_key = {}
    conditions = []
    cond_tables = []
    quit = False
    in_operation = False
    cur_block_sz = 0
    params = None
    logger = None
    setup_conf_idx = 0
    interface = None
    beh = None
    trial_start = 0  # time in ms of the trial start

    def setup(self, logger: Logger, behavior_class, session_params: Dict) -> None:
        """Set up Experiment."""
        self.in_operation = False
        self.conditions = []
        self.iter = []
        self.quit = False
        self.curr_cond = {}
        self.block_h = []
        self.stims = dict()
        self.curr_trial = 0
        self.cur_block_sz = 0

        self.session_params = self.setup_session_params(session_params, self.default_key)

        self.setup_conf_idx = self.session_params["setup_conf_idx"]

        self.logger = logger
        self.beh = behavior_class()
        self.interface = self._interface_setup(
            self.beh, self.logger, self.setup_conf_idx
        )
        self.interface.load_calibration()
        self.beh.setup(self)

        self.logger.log_session(
            self.session_params, experiment_type=self.cond_tables[0], log_task=True
        )

        self.session_timer = Timer()

        np.random.seed(0)  # fix random seed, it can be overridden in the task file

    def setup_session_params(
        self, session_params: Dict[str, Any], default_key: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Set up session parameters with validation.

        Args:
            session_params: Dictionary of session parameters
            default_key: Dictionary of default parameter values
        Returns:
            Dictionary of session parameters with defaults applied
        """
        # Convert dict to SessionParameters for validation and defaults
        params = SessionParameters.from_dict(session_params, default_key)
        params.validate()
        return params.to_dict()

    def _interface_setup(self, beh, logger: Logger, setup_conf_idx: int) -> "Interface":  # noqa: F821
        interface_module = logger.get(
            schema="interface",
            table="SetupConfiguration",
            fields=["interface"],
            key={"setup_conf_idx": setup_conf_idx},
        )[0]
        log.info(f"Interface: {interface_module}")
        interface = getattr(
            import_module(f"ethopy.interfaces.{interface_module}"), interface_module
        )

        return interface(exp=self, beh=beh)

    def start(self) -> None:
        """Start the StateMachine."""
        states = dict()
        for state in self.__class__.__subclasses__():  # Initialize states
            states.update({state().__class__.__name__: state(self)})
        state_control = StateMachine(states)
        self.interface.set_operation_status(True)
        state_control.run()

    def stop(self) -> None:
        """Stop the epxeriment."""
        self.stim.exit()
        self.interface.release()
        self.beh.exit()
        if self.sync:
            while self.interface.is_recording():
                log.info("Waiting for recording to end...")
                time.sleep(1)
        self.logger.closeDatasets()
        self.in_operation = False

    def is_stopped(self) -> bool:
        """Check if the experiment should stop."""
        self.quit = self.quit or self.logger.setup_status in ["stop", "exit"]
        if self.quit and self.logger.setup_status not in ["stop", "exit"]:
            self.logger.update_setup_info({"status": "stop"})
        if self.quit:
            self.in_operation = False
        return self.quit

    def _stim_init(self, stim_class, stims: Dict) -> Dict:
        # get stimulus class name
        stim_name = stim_class.name()
        if stim_name not in stims:
            stim_class.init(self)
            stims[stim_name] = stim_class
        return stims

    def get_keys_from_dict(self, data: Dict, keys: List) -> Dict:
        """Efficiently extract specific keys from a dictionary.

        Args:
            data (dict): The input dictionary.
            keys (list): The list of keys to extract.

        Returns:
            (dict): A new dictionary with only the specified keys if they exist.

        """
        keys_set = set(keys)  # Convert list to set for O(1) lookup
        return {key: data[key] for key in keys_set.intersection(data)}

    def _get_task_classes(self, stim_class) -> Dict:
        exp_name = {"experiment_class": self.cond_tables[0]}
        beh_name = {
            "behavior_class": self.beh.cond_tables[0]
            if self.beh.cond_tables
            else "None"
        }
        stim_name = {"stimulus_class": stim_class.name()}
        return {**exp_name, **beh_name, **stim_name}

    def make_conditions(
        self,
        stim_class,
        conditions: Dict[str, Any],
        stim_periods: List[str] = None,
    ) -> List[Dict]:
        """Create conditions by combining stimulus, behavior, and experiment."""
        log.debug("-------------- Make conditions --------------")
        self.stims = self._stim_init(stim_class, self.stims)
        used_keys = set()  # all the keys used from dictionary conditions

        # Handle stimulus conditions
        stim_conditions, stim_keys = self._process_stim_conditions(
            stim_class, conditions, stim_periods
        )
        used_keys.update(stim_keys)

        # Process behavior conditions
        beh_conditions, beh_keys = self._process_behavior_conditions(conditions)
        used_keys.update(beh_keys)

        # Process experiment conditions
        exp_conditions, exp_keys = self._process_experiment_conditions(
            self._get_task_classes(stim_class), conditions
        )
        used_keys.update(exp_keys)

        # Combine results and handle unused parameters
        partial_results = [exp_conditions, beh_conditions, stim_conditions]
        unused_conditions = self._handle_unused_parameters(conditions, used_keys)
        if unused_conditions:
            partial_results.append(unused_conditions)
        log.debug("-----------------------------------------------")
        return [
            {k: v for d in comb for k, v in d.items()}
            for comb in product(*partial_results)
        ]

    def _process_stim_conditions(
        self, stim_class, conditions: Dict, stim_periods: List
    ) -> Tuple[List, List]:
        """Process stimulus-specific conditions."""
        if stim_periods:
            period_conditions = {}
            for period in stim_periods:
                stim_dict = self.get_keys_from_dict(
                    conditions[period], get_parameters(stim_class).keys()
                )
                log.debug(
                    f"Stimulus period: {period} use default conditions:"
                    f"\n{get_parameters(stim_class).keys() - stim_dict.keys()}"
                )
                period_conditions[period] = factorize(stim_dict)
                period_conditions[period] = self.stims[
                    stim_class.name()
                ].make_conditions(period_conditions[period])
                for i, stim_condition in enumerate(period_conditions[period]):
                    log.debug(
                        f"Stimulus condition {i}:\n"
                        f"{format_params_print(stim_condition)}"
                    )
            stim_conditions = factorize(period_conditions)
            return stim_conditions, stim_periods

        stim_dict = self.get_keys_from_dict(
            conditions, get_parameters(stim_class).keys()
        )
        log.debug(
            f"Stimulus use default conditions:\n"
            f"{get_parameters(stim_class).keys() - stim_dict.keys()}"
        )
        stim_conditions = factorize(stim_dict)
        stim_conditions = self.stims[stim_class.name()].make_conditions(stim_conditions)
        for i, stim_condition in enumerate(stim_conditions):
            log.debug(f"Stimulus condition {i}:\n{format_params_print(stim_condition)}")
        return stim_conditions, stim_dict.keys()

    def _process_behavior_conditions(self, conditions: Dict) -> Tuple[List, List]:
        """Process behavior-related conditions."""
        beh_dict = self.get_keys_from_dict(conditions, get_parameters(self.beh).keys())
        log.debug(
            f"Behavior use default conditions:\n"
            f"{get_parameters(self.beh).keys() - beh_dict.keys()}"
        )
        beh_conditions = factorize(beh_dict)
        beh_conditions = self.beh.make_conditions(beh_conditions)
        for i, beh_condition in enumerate(beh_conditions):
            log.debug(f"Behavior condition {i}:\n{format_params_print(beh_condition)}")
        return beh_conditions, beh_dict.keys()

    def _process_experiment_conditions(
        self, task_classes: Dict, conditions: Dict
    ) -> Tuple[List, List]:
        """Process experiment-wide conditions."""
        exp_dict = self.get_keys_from_dict(conditions, get_parameters(self).keys())
        exp_dict.update(task_classes)
        log.debug(
            f"Experiment use default conditions:\n"
            f"{get_parameters(self).keys() - exp_dict.keys()}"
        )
        exp_conditions = factorize(exp_dict)

        for cond in exp_conditions:
            self.validate_condition(cond)
            cond.update({**self.default_key, **self.session_params, **cond})
        cond_tables = ["Condition." + table for table in self.cond_tables]
        conditions_list = self.log_conditions(
            exp_conditions, condition_tables=["Condition"] + cond_tables
        )
        for i, exp_condition in enumerate(exp_conditions):
            log.debug(
                f"Experiment condition {i}:\n{format_params_print(exp_condition)}"
            )
        return conditions_list, exp_dict.keys()

    def _handle_unused_parameters(self, conditions, used_keys) -> Union[List, None]:
        """Process any unused parameters."""
        unused_keys = set(conditions.keys()) - used_keys
        if unused_keys:
            log.warning(
                f"Keys: {unused_keys} are in condition but are not used from "
                f"Experiment, Behavior or Stimulus"
            )
            unused_dict = self.get_keys_from_dict(conditions, unused_keys)
            return factorize(unused_dict)
        return None

    def validate_condition(self, condition: Dict) -> None:
        """Validate a condition dictionary against the required fields.

        Args:
            condition (Dict): The condition dictionary to validate.

        Raises:
            ValueError: If required fields are missing from the condition.

        """
        missing_fields = set(self.required_fields) - set(condition)
        if missing_fields:
            raise ValueError(f"Missing experiment required fields: {missing_fields}")

    def push_conditions(self, conditions: List) -> None:
        """Set the experimental conditions and initializes related data structures.

        This method takes a list of condition dictionaries, prepares data structures
        for tracking choices, blocks, and the current condition.  It also determines
        unique choice hashes based on the response condition and difficulty.

        Args:
            conditions: A list of dictionaries, where each dictionary
                represents an experimental condition.  Each condition
                dictionary is expected to contain at least a "difficulty"
                key.  If a `resp_cond` key (or the default "response_port")
                is present, it's used to create unique choice hashes.

        """
        log.info(f"Number of conditions: {len(conditions)}")
        self.conditions = conditions
        self.blocks = np.array([cond["difficulty"] for cond in self.conditions])
        if np.all(["response_port" in cond for cond in conditions]):
            self.choices = np.array(
                [make_hash([d["response_port"], d["difficulty"]]) for d in conditions]
            )
            self.un_choices, un_idx = np.unique(self.choices, axis=0, return_index=True)
            self.un_blocks = self.blocks[un_idx]
        #  select random condition for first trial initialization
        self.cur_block = min(self.blocks)
        self.curr_cond = np.random.choice(
            [i for (i, v) in zip(self.conditions, self.blocks == self.cur_block) if v]
        )

    def prepare_trial(self) -> None:
        """Prepare trial conditions, stimuli and update trial index."""
        old_cond = self.curr_cond
        self._get_new_cond()

        if not self.curr_cond or self.logger.thread_end.is_set():
            self.quit = True
            return
        if (
            "stimulus_class" not in old_cond
            or self.curr_trial == 0
            or old_cond["stimulus_class"] != self.curr_cond["stimulus_class"]
        ):
            if "stimulus_class" in old_cond and self.curr_trial != 0:
                self.stim.exit()
            self.stim = self.stims[self.curr_cond["stimulus_class"]]
            log.debug("setting up stimulus")
            self.stim.setup()
            log.debug("stimuli is done")
        self.curr_trial += 1
        self.logger.update_trial_idx(self.curr_trial)
        self.trial_start = self.logger.logger_timer.elapsed_time()
        self.logger.log(
            "Trial",
            dict(cond_hash=self.curr_cond["cond_hash"], time=self.trial_start),
            priority=3,
        )
        if not self.in_operation:
            self.in_operation = True

    def name(self) -> str:
        """Name of experiment class."""
        return type(self).__name__

    def _make_cond_hash(
        self,
        conditions: List[Dict],
        hash_field: str,
        schema: dj.schema,
        condition_tables: List,
    ) -> List[Dict]:
        """Make unique hash based on all fields from condition tables."""
        # get all fields from condition tables except hash
        fields_key = {
            key
            for ctable in condition_tables
            for key in self.logger.get_table_keys(schema, ctable)
        }
        fields_key.discard(hash_field)
        for condition in conditions:
            # find all dependant fields and generate hash
            key = {k: condition[k] for k in fields_key if k in condition}
            condition.update({hash_field: make_hash(key)})
        return conditions

    def log_conditions(
        self,
        conditions,
        condition_tables=None,
        schema="experiment",
        hash_field="cond_hash",
        priority=2,
    ) -> List[Dict]:
        """Log experimental conditions to specified tables with hashes tracking.

        Args:
            conditions (List): List of condition dictionaries to log
            condition_tables (List): List of table names to log to
            schema (str): Database schema name
            hash_field (str): Name of the hash field
            priority (int): for the insertion order of the logger

        Returns:
            List of processed conditions with added hashes

        """
        if not conditions:
            return []

        if condition_tables is None:
            condition_tables = ["Condition"]

        conditions = self._make_cond_hash(
            conditions, hash_field, schema, condition_tables
        )

        processed_conditions = conditions.copy()
        for condition in processed_conditions:
            _priority = priority
            # insert condition fields into the corresponding table
            for ctable in condition_tables:
                # Get table metadata
                fields = set(self.logger.get_table_keys(schema, ctable))
                primary_keys = set(
                    self.logger.get_table_keys(schema, ctable, key_type="primary")
                )
                core = [key for key in primary_keys if key != hash_field]

                # Validate condition has all required fields
                missing_keys = set(fields) - set(condition.keys())
                if missing_keys:
                    log.warning(f"Skipping {ctable}, Missing keys:{missing_keys}")
                    continue

                # check if there is a primary key which is not hash and it is iterable
                if core and hasattr(condition[core[0]], "__iter__"):
                    # TODO make a function for this and clarify it
                    # If any of the primary keys is iterable all the rest should be.
                    # The first element of the iterable will be matched with the first
                    # element of the rest of the keys
                    for idx, _ in enumerate(condition[core[0]]):
                        cond_key = {}
                        for k in fields:
                            if isinstance(condition[k], (int, float, str)):
                                cond_key[k] = condition[k]
                            else:
                                cond_key[k] = condition[k][idx]

                        self.logger.put(
                            table=ctable,
                            tuple=cond_key,
                            schema=schema,
                            priority=_priority,
                        )

                else:
                    self.logger.put(
                        table=ctable, tuple=condition, schema=schema, priority=_priority
                    )

                # Increment the priority for each subsequent table
                # to ensure they are inserted in the correct order
                _priority += 1

        return processed_conditions

    def _anti_bias(self, choice_h, un_choices):
        choice_h = np.array(
            [make_hash(c) for c in choice_h[-self.curr_cond["bias_window"] :]]
        )
        if len(choice_h) < self.curr_cond["bias_window"]:
            choice_h = self.choices
        fixed_p = 1 - np.array([np.mean(choice_h == un) for un in un_choices])
        if sum(fixed_p) == 0:
            fixed_p = np.ones(np.shape(fixed_p))
        return np.random.choice(un_choices, 1, p=fixed_p / sum(fixed_p))

    def _get_new_cond(self) -> None:
        """Get next condition based on trial selection method."""
        selection_method = self.curr_cond["trial_selection"]
        selection_handlers = {
            "fixed": self._fixed_selection,
            "block": self._block_selection,
            "random": self._random_selection,
            "staircase": self._staircase_selection,
            "biased": self._biased_selection,
        }

        handler = selection_handlers.get(selection_method)
        if handler:
            self.curr_cond = handler()
        else:
            log.error(f"Selection method '{selection_method}' not implemented!")
            self.quit = True

    def _fixed_selection(self) -> Dict:
        """Select next condition by popping from ordered list."""
        return [] if len(self.conditions) == 0 else self.conditions.pop(0)

    def _block_selection(self) -> Dict:
        """Select random condition from a block.

        Select a condition from a block of conditions until all
        conditions have been selected, then repeat them randomly.
        """
        if np.size(self.iter) == 0:
            self.iter = np.random.permutation(np.size(self.conditions))
        cond = self.conditions[self.iter[0]]
        self.iter = self.iter[1:]
        return cond

    def _random_selection(self) -> Dict:
        """Select random condition from available conditions."""
        return np.random.choice(self.conditions)

    def _update_block_difficulty(self, perf: float) -> None:
        """Update block difficulty based on performance.

        Args:
            perf: Current performance metric

        """
        if self.cur_block_sz >= self.curr_cond["staircase_window"]:
            if perf >= self.curr_cond["stair_up"]:
                self.cur_block = self.curr_cond["next_up"]
                self.cur_block_sz = 0
                self.logger.update_setup_info({"difficulty": self.cur_block})
            elif perf < self.curr_cond["stair_down"]:
                self.cur_block = self.curr_cond["next_down"]
                self.cur_block_sz = 0
                self.logger.update_setup_info({"difficulty": self.cur_block})

    def _get_valid_conditions(self, condition_idx: np.ndarray) -> List[Dict]:
        """Get list of valid conditions based on condition index.

        Args:
            condition_idx: Boolean array indicating valid conditions

        Returns:
            List of valid condition dictionaries

        """
        return [c for c, v in zip(self.conditions, condition_idx) if v]

    def _staircase_selection(self) -> Dict:
        """Select next condition using staircase method."""
        # Get performance metrics
        perf, choice_h = self._get_performance()

        # Update block size if there was a choice in last trial
        if np.size(self.beh.choice_history) and self.beh.choice_history[-1:][0] > 0:
            self.cur_block_sz += 1

        # Update difficulty if needed
        self._update_block_difficulty(perf)

        # Select condition based on current block and anti-bias
        if self.curr_cond["antibias"]:
            valid_choices = self.un_choices[self.un_blocks == self.cur_block]
            anti_bias = self._anti_bias(choice_h, valid_choices)
            condition_idx = np.logical_and(
                self.choices == anti_bias, self.blocks == self.cur_block
            )
        else:
            condition_idx = self.blocks == self.cur_block

        valid_conditions = self._get_valid_conditions(condition_idx)
        self.block_h.append(self.cur_block)
        return np.random.choice(valid_conditions)

    def _biased_selection(self) -> Dict:
        """Select next condition using anti-bias method."""
        perf, choice_h = self._get_performance()
        anti_bias = self._anti_bias(choice_h, self.un_choices)
        condition_idx = self.choices == anti_bias
        valid_conditions = self._get_valid_conditions(condition_idx)
        return np.random.choice(valid_conditions)

    def add_selection_method(self, name: str, handler: Callable[[], Dict]) -> None:
        """Add a new trial selection method.

        Args:
            name: Name of the selection method
            handler: Function that returns next condition

        Example:
            def my_selection_method(self):
                # Custom selection logic
                return selected_condition

            experiment.add_selection_method('custom', my_selection_method)

        """
        if not hasattr(self, f"_{name}_selection"):
            setattr(self, f"_{name}_selection", handler)
            log.info(f"Added new selection method: {name}")
        else:
            log.warning(f"Selection method '{name}' already exists")

    def _get_performance(self) -> Tuple[float, List[List[int]]]:
        """Calculate performance metrics based on trial history."""
        rewards, choices, blocks = self._extract_valid_trial_data()

        if not rewards.size:  # Check if there are any valid trials
            return np.nan, []

        window = self.curr_cond["staircase_window"]
        recent_rewards = rewards[-window:]
        recent_choices = choices[-window:]
        recent_blocks = blocks[-window:] if blocks is not None else None

        performance = self._calculate_metric(
            recent_rewards, recent_choices, recent_blocks
        )

        choice_history = self._get_choice_history(choices, blocks)

        log.debug(
            f"\nstaircase_window: {window},\n"
            f"rewards: {recent_rewards},\n"
            f"choices: {recent_choices},\n"
            f"blocks: {recent_blocks},\n"
            f"performace: {performance}"
        )

        return performance, choice_history

    def _extract_valid_trial_data(
        self,
    ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray]]:
        """Extract trials that are either punish or reward.

        rewards: reward amount given at trials that have been rewarded else nan
        choices: selected port in reward & punish trials
        blocks: block index in each trial that is reward or punish

        """
        valid_idx = np.logical_or(
            ~np.isnan(self.beh.reward_history), ~np.isnan(self.beh.punish_history)
        )

        rewards = np.asarray(self.beh.reward_history)[valid_idx]
        choices = np.int64(np.asarray(self.beh.choice_history)[valid_idx])
        blocks = np.asarray(self.block_h)[valid_idx] if self.block_h else None

        return rewards, choices, blocks

    def _calculate_accuracy(
        self, rewards: np.ndarray, choices: np.ndarray, blocks: Optional[np.ndarray]
    ) -> float:
        """Calculate accuracy from trial data."""
        return np.nanmean(np.greater(rewards, 0))

    def _calculate_dprime(
        self, rewards: np.ndarray, choices: np.ndarray, blocks: Optional[np.ndarray]
    ) -> float:
        """Calculate d-prime from trial data."""
        y_true = [c if r > 0 else c % 2 + 1 for (c, r) in zip(choices, rewards)]

        if len(np.unique(y_true)) > 1:
            auc = roc_auc_score(y_true, choices)
            return np.sqrt(2) * stats.norm.ppf(auc)

        return np.nan

    def _calculate_metric(
        self, rewards: np.ndarray, choices: np.ndarray, blocks: Optional[np.ndarray]
    ) -> float:
        """Calculate performance metric specified in current condition."""
        metric_handlers = {
            "accuracy": self._calculate_accuracy,
            "dprime": self._calculate_dprime,
        }

        handler = metric_handlers.get(self.curr_cond["metric"])
        if handler:
            return handler(rewards, choices, blocks)
        else:
            log.error(
                f"Performance metric '{self.curr_cond['metric']}' not implemented!"
            )
            self.quit = True
            return np.nan

    def _get_choice_history(
        self, choices: np.ndarray, blocks: Optional[np.ndarray]
    ) -> List[List[int]]:
        """Create choice history with difficulty levels."""
        if blocks is not None:
            return [[c, d] for c, d in zip(choices, blocks)]
        return [[c, 0] for c in choices]

    def add_performance_metric(
        self,
        name: str,
        handler: Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float],
    ) -> None:
        """Add a new performance metric calculation method.

        Args:
            name: Name of the metric
            handler: Function that takes ValidTrials and returns performance score

        Example:
            def calculate_custom_metric(trials):
                # Custom metric calculation
                return score

            experiment.add_performance_metric('custom', calculate_custom_metric)

        """
        if not hasattr(self, f"{name}"):
            setattr(self, f"{name}", handler)
            log.info(f"Added new performance metric: {name}")
        else:
            log.warning(f"Performance metric '{name}' already exists")

    @dataclass
    class Block:
        """A class representing a block of trials in an experiment.

        Args:
            difficulty (int): The difficulty level of the block. Default is 0.
            stair_up (float): Threshold given to compare if performance is higher in
                order to go to the next_up difficulty. Default is 0.7.
            stair_down (float): Threshold given to compare if performance is smaller in
                order to go to the next_down difficulty. Default is 0.55.
            next_up (int): The difficulty level to go to if perf>stair_up. Default is 0.
            next_down (int): The difficulty level to go to if perf<stair_down.
                Default is 0.
            staircase_window (int): The window size for the staircase procedure.
                Default is 20.
            bias_window (int): The window size for bias correction. Default is 5.
            trial_selection (str): The method for selecting trials. Default is "fixed".
            metric (str): The metric used for evaluating performance. Default is
                "accuracy".
            antibias (bool): Whether to apply antibias correction. Default is True.
        """

        difficulty: int = field(compare=True, default=0, hash=True)
        stair_up: float = field(compare=False, default=0.7)
        stair_down: float = field(compare=False, default=0.55)
        next_up: int = field(compare=False, default=0)
        next_down: int = field(compare=False, default=0)
        staircase_window: int = field(compare=False, default=20)
        bias_window: int = field(compare=False, default=5)
        trial_selection: str = field(compare=False, default="fixed")
        metric: str = field(compare=False, default="accuracy")
        antibias: bool = field(compare=False, default=True)

        def dict(self) -> Dict:
            """Rerurn parameters as dictionary."""
            return self.__dict__

Block dataclass

A class representing a block of trials in an experiment.

Parameters:

- difficulty (int): The difficulty level of the block. Default: 0.
- stair_up (float): Threshold given to compare if performance is higher in order to go to the next_up difficulty. Default: 0.7.
- stair_down (float): Threshold given to compare if performance is smaller in order to go to the next_down difficulty. Default: 0.55.
- next_up (int): The difficulty level to go to if perf > stair_up. Default: 0.
- next_down (int): The difficulty level to go to if perf < stair_down. Default: 0.
- staircase_window (int): The window size for the staircase procedure. Default: 20.
- bias_window (int): The window size for bias correction. Default: 5.
- trial_selection (str): The method for selecting trials. Default: 'fixed'.
- metric (str): The metric used for evaluating performance. Default: 'accuracy'.
- antibias (bool): Whether to apply antibias correction. Default: True.

Source code in src/ethopy/core/experiment.py
@dataclass
class Block:
    """A class representing a block of trials in an experiment.

    Args:
        difficulty (int): The difficulty level of the block. Default is 0.
        stair_up (float): Threshold given to compare if performance is higher in
            order to go to the next_up difficulty. Default is 0.7.
        stair_down (float): Threshold given to compare if performance is smaller in
            order to go to the next_down difficulty. Default is 0.55.
        next_up (int): The difficulty level to go to if perf>stair_up. Default is 0.
        next_down (int): The difficulty level to go to if perf<stair_down.
            Default is 0.
        staircase_window (int): The window size for the staircase procedure.
            Default is 20.
        bias_window (int): The window size for bias correction. Default is 5.
        trial_selection (str): The method for selecting trials. Default is "fixed".
        metric (str): The metric used for evaluating performance. Default is
            "accuracy".
        antibias (bool): Whether to apply antibias correction. Default is True.
    """

    difficulty: int = field(compare=True, default=0, hash=True)
    stair_up: float = field(compare=False, default=0.7)
    stair_down: float = field(compare=False, default=0.55)
    next_up: int = field(compare=False, default=0)
    next_down: int = field(compare=False, default=0)
    staircase_window: int = field(compare=False, default=20)
    bias_window: int = field(compare=False, default=5)
    trial_selection: str = field(compare=False, default="fixed")
    metric: str = field(compare=False, default="accuracy")
    antibias: bool = field(compare=False, default=True)

    def dict(self) -> Dict:
        """Rerurn parameters as dictionary."""
        return self.__dict__

dict()

Return parameters as a dictionary.

Source code in src/ethopy/core/experiment.py
def dict(self) -> Dict:
    """Rerurn parameters as dictionary."""
    return self.__dict__
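
In a task file, a Block instance can bundle block-level defaults and feed them into the condition parameters; a minimal sketch (the stimulus instance and extra keys are illustrative):

    block = ExperimentClass.Block(
        difficulty=1,
        next_up=2,
        next_down=0,
        trial_selection="staircase",
    )
    block_params = block.dict()        # {'difficulty': 1, 'stair_up': 0.7, ...}
    conditions = exp.make_conditions(
        stim_class=MyStimulus(),       # hypothetical stimulus instance
        conditions={**block_params, "response_port": 1},
    )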

add_performance_metric(name, handler)

Add a new performance metric calculation method.

Parameters:

- name (str): Name of the metric. Required.
- handler (Callable[[ndarray, ndarray, Optional[ndarray]], float]): Function that takes ValidTrials and returns performance score. Required.

Example:

    def calculate_custom_metric(trials):
        # Custom metric calculation
        return score

    experiment.add_performance_metric('custom', calculate_custom_metric)

Source code in src/ethopy/core/experiment.py
def add_performance_metric(
    self,
    name: str,
    handler: Callable[[np.ndarray, np.ndarray, Optional[np.ndarray]], float],
) -> None:
    """Add a new performance metric calculation method.

    Args:
        name: Name of the metric
        handler: Function that takes ValidTrials and returns performance score

    Example:
        def calculate_custom_metric(trials):
            # Custom metric calculation
            return score

        experiment.add_performance_metric('custom', calculate_custom_metric)

    """
    if not hasattr(self, f"{name}"):
        setattr(self, f"{name}", handler)
        log.info(f"Added new performance metric: {name}")
    else:
        log.warning(f"Performance metric '{name}' already exists")

add_selection_method(name, handler)

Add a new trial selection method.

Parameters:

- name (str): Name of the selection method. Required.
- handler (Callable[[], Dict]): Function that returns next condition. Required.

Example:

    def my_selection_method(self):
        # Custom selection logic
        return selected_condition

    experiment.add_selection_method('custom', my_selection_method)

Source code in src/ethopy/core/experiment.py
def add_selection_method(self, name: str, handler: Callable[[], Dict]) -> None:
    """Add a new trial selection method.

    Args:
        name: Name of the selection method
        handler: Function that returns next condition

    Example:
        def my_selection_method(self):
            # Custom selection logic
            return selected_condition

        experiment.add_selection_method('custom', my_selection_method)

    """
    if not hasattr(self, f"_{name}_selection"):
        setattr(self, f"_{name}_selection", handler)
        log.info(f"Added new selection method: {name}")
    else:
        log.warning(f"Selection method '{name}' already exists")

get_keys_from_dict(data, keys)

Efficiently extract specific keys from a dictionary.

Parameters:

- data (dict): The input dictionary. Required.
- keys (list): The list of keys to extract. Required.

Returns:

- (dict): A new dictionary with only the specified keys if they exist.

Source code in src/ethopy/core/experiment.py
def get_keys_from_dict(self, data: Dict, keys: List) -> Dict:
    """Efficiently extract specific keys from a dictionary.

    Args:
        data (dict): The input dictionary.
        keys (list): The list of keys to extract.

    Returns:
        (dict): A new dictionary with only the specified keys if they exist.

    """
    keys_set = set(keys)  # Convert list to set for O(1) lookup
    return {key: data[key] for key in keys_set.intersection(data)}
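
A quick illustration of the filtering behavior (exp stands for any ExperimentClass instance; the keys are arbitrary):

    exp.get_keys_from_dict({"theta": 90, "duration": 500, "unused": 1}, ["theta", "duration"])
    # -> {'theta': 90, 'duration': 500}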

is_stopped()

Check if the experiment should stop.

Source code in src/ethopy/core/experiment.py
def is_stopped(self) -> bool:
    """Check if the experiment should stop."""
    self.quit = self.quit or self.logger.setup_status in ["stop", "exit"]
    if self.quit and self.logger.setup_status not in ["stop", "exit"]:
        self.logger.update_setup_info({"status": "stop"})
    if self.quit:
        self.in_operation = False
    return self.quit

log_conditions(conditions, condition_tables=None, schema='experiment', hash_field='cond_hash', priority=2)

Log experimental conditions to specified tables with hashes tracking.

Parameters:

- conditions (List): List of condition dictionaries to log. Required.
- condition_tables (List): List of table names to log to. Default: None.
- schema (str): Database schema name. Default: 'experiment'.
- hash_field (str): Name of the hash field. Default: 'cond_hash'.
- priority (int): Insertion priority used by the logger. Default: 2.

Returns:

- (List[Dict]): List of processed conditions with added hashes.

Source code in src/ethopy/core/experiment.py
def log_conditions(
    self,
    conditions,
    condition_tables=None,
    schema="experiment",
    hash_field="cond_hash",
    priority=2,
) -> List[Dict]:
    """Log experimental conditions to specified tables with hashes tracking.

    Args:
        conditions (List): List of condition dictionaries to log
        condition_tables (List): List of table names to log to
        schema (str): Database schema name
        hash_field (str): Name of the hash field
        priority (int): for the insertion order of the logger

    Returns:
        List of processed conditions with added hashes

    """
    if not conditions:
        return []

    if condition_tables is None:
        condition_tables = ["Condition"]

    conditions = self._make_cond_hash(
        conditions, hash_field, schema, condition_tables
    )

    processed_conditions = conditions.copy()
    for condition in processed_conditions:
        _priority = priority
        # insert condition fields into the corresponding table
        for ctable in condition_tables:
            # Get table metadata
            fields = set(self.logger.get_table_keys(schema, ctable))
            primary_keys = set(
                self.logger.get_table_keys(schema, ctable, key_type="primary")
            )
            core = [key for key in primary_keys if key != hash_field]

            # Validate condition has all required fields
            missing_keys = set(fields) - set(condition.keys())
            if missing_keys:
                log.warning(f"Skipping {ctable}, Missing keys:{missing_keys}")
                continue

            # check if there is a primary key which is not hash and it is iterable
            if core and hasattr(condition[core[0]], "__iter__"):
                # TODO make a function for this and clarify it
                # If any of the primary keys is iterable all the rest should be.
                # The first element of the iterable will be matched with the first
                # element of the rest of the keys
                for idx, _ in enumerate(condition[core[0]]):
                    cond_key = {}
                    for k in fields:
                        if isinstance(condition[k], (int, float, str)):
                            cond_key[k] = condition[k]
                        else:
                            cond_key[k] = condition[k][idx]

                    self.logger.put(
                        table=ctable,
                        tuple=cond_key,
                        schema=schema,
                        priority=_priority,
                    )

            else:
                self.logger.put(
                    table=ctable, tuple=condition, schema=schema, priority=_priority
                )

            # Increment the priority for each subsequent table
            # to ensure they are inserted in the correct order
            _priority += 1

    return processed_conditions
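
This method is normally called internally while conditions are being built, but it can also be invoked directly; a sketch with a hypothetical experiment-specific child table:

    processed = exp.log_conditions(
        conditions,                                   # list of condition dicts
        condition_tables=["Condition", "Condition.MyExperiment"],  # hypothetical table name
    )
    # each returned condition now carries a 'cond_hash' field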

make_conditions(stim_class, conditions, stim_periods=None)

Create conditions by combining stimulus, behavior, and experiment.

Source code in src/ethopy/core/experiment.py
def make_conditions(
    self,
    stim_class,
    conditions: Dict[str, Any],
    stim_periods: List[str] = None,
) -> List[Dict]:
    """Create conditions by combining stimulus, behavior, and experiment."""
    log.debug("-------------- Make conditions --------------")
    self.stims = self._stim_init(stim_class, self.stims)
    used_keys = set()  # all the keys used from dictionary conditions

    # Handle stimulus conditions
    stim_conditions, stim_keys = self._process_stim_conditions(
        stim_class, conditions, stim_periods
    )
    used_keys.update(stim_keys)

    # Process behavior conditions
    beh_conditions, beh_keys = self._process_behavior_conditions(conditions)
    used_keys.update(beh_keys)

    # Process experiment conditions
    exp_conditions, exp_keys = self._process_experiment_conditions(
        self._get_task_classes(stim_class), conditions
    )
    used_keys.update(exp_keys)

    # Combine results and handle unused parameters
    partial_results = [exp_conditions, beh_conditions, stim_conditions]
    unused_conditions = self._handle_unused_parameters(conditions, used_keys)
    if unused_conditions:
        partial_results.append(unused_conditions)
    log.debug("-----------------------------------------------")
    return [
        {k: v for d in comb for k, v in d.items()}
        for comb in product(*partial_results)
    ]
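
A sketch of building conditions from a small parameter grid (the stimulus class and parameter names are illustrative, and it is assumed that factorize expands list-valued parameters into all combinations):

    conditions = exp.make_conditions(
        stim_class=MyStimulus(),            # hypothetical stimulus
        conditions={
            "contrast": [20, 100],          # hypothetical stimulus parameter
            "response_port": [1, 2],
            "difficulty": 0,
        },
    )
    # one condition dict per combination, each tagged with the
    # experiment_class, behavior_class and stimulus_class names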

name()

Name of experiment class.

Source code in src/ethopy/core/experiment.py
def name(self) -> str:
    """Name of experiment class."""
    return type(self).__name__

prepare_trial()

Prepare trial conditions, stimuli and update trial index.

Source code in src/ethopy/core/experiment.py
def prepare_trial(self) -> None:
    """Prepare trial conditions, stimuli and update trial index."""
    old_cond = self.curr_cond
    self._get_new_cond()

    if not self.curr_cond or self.logger.thread_end.is_set():
        self.quit = True
        return
    if (
        "stimulus_class" not in old_cond
        or self.curr_trial == 0
        or old_cond["stimulus_class"] != self.curr_cond["stimulus_class"]
    ):
        if "stimulus_class" in old_cond and self.curr_trial != 0:
            self.stim.exit()
        self.stim = self.stims[self.curr_cond["stimulus_class"]]
        log.debug("setting up stimulus")
        self.stim.setup()
        log.debug("stimuli is done")
    self.curr_trial += 1
    self.logger.update_trial_idx(self.curr_trial)
    self.trial_start = self.logger.logger_timer.elapsed_time()
    self.logger.log(
        "Trial",
        dict(cond_hash=self.curr_cond["cond_hash"], time=self.trial_start),
        priority=3,
    )
    if not self.in_operation:
        self.in_operation = True

push_conditions(conditions)

Set the experimental conditions and initialize related data structures.

This method takes a list of condition dictionaries, prepares data structures for tracking choices, blocks, and the current condition. It also determines unique choice hashes based on the response condition and difficulty.

Parameters:

- conditions (List): A list of dictionaries, where each dictionary represents an experimental condition. Each condition dictionary is expected to contain at least a "difficulty" key. If a resp_cond key (or the default "response_port") is present, it is used to create unique choice hashes. Required.
Source code in src/ethopy/core/experiment.py
def push_conditions(self, conditions: List) -> None:
    """Set the experimental conditions and initializes related data structures.

    This method takes a list of condition dictionaries, prepares data structures
    for tracking choices, blocks, and the current condition.  It also determines
    unique choice hashes based on the response condition and difficulty.

    Args:
        conditions: A list of dictionaries, where each dictionary
            represents an experimental condition.  Each condition
            dictionary is expected to contain at least a "difficulty"
            key.  If a `resp_cond` key (or the default "response_port")
            is present, it's used to create unique choice hashes.

    """
    log.info(f"Number of conditions: {len(conditions)}")
    self.conditions = conditions
    self.blocks = np.array([cond["difficulty"] for cond in self.conditions])
    if np.all(["response_port" in cond for cond in conditions]):
        self.choices = np.array(
            [make_hash([d["response_port"], d["difficulty"]]) for d in conditions]
        )
        self.un_choices, un_idx = np.unique(self.choices, axis=0, return_index=True)
        self.un_blocks = self.blocks[un_idx]
    #  select random condition for first trial initialization
    self.cur_block = min(self.blocks)
    self.curr_cond = np.random.choice(
        [i for (i, v) in zip(self.conditions, self.blocks == self.cur_block) if v]
    )
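
Continuing the sketch from make_conditions, the generated list is handed to the experiment before starting; every condition must carry a "difficulty" key, and "response_port" keys enable the choice/anti-bias bookkeeping:

    exp.push_conditions(conditions)
    exp.start()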

setup(logger, behavior_class, session_params)

Set up Experiment.

Source code in src/ethopy/core/experiment.py
def setup(self, logger: Logger, behavior_class, session_params: Dict) -> None:
    """Set up Experiment."""
    self.in_operation = False
    self.conditions = []
    self.iter = []
    self.quit = False
    self.curr_cond = {}
    self.block_h = []
    self.stims = dict()
    self.curr_trial = 0
    self.cur_block_sz = 0

    self.session_params = self.setup_session_params(session_params, self.default_key)

    self.setup_conf_idx = self.session_params["setup_conf_idx"]

    self.logger = logger
    self.beh = behavior_class()
    self.interface = self._interface_setup(
        self.beh, self.logger, self.setup_conf_idx
    )
    self.interface.load_calibration()
    self.beh.setup(self)

    self.logger.log_session(
        self.session_params, experiment_type=self.cond_tables[0], log_task=True
    )

    self.session_timer = Timer()

    np.random.seed(0)  # fix random seed, it can be overridden in the task file

setup_session_params(session_params, default_key)

Set up session parameters with validation.

Parameters:

- session_params (Dict[str, Any]): Dictionary of session parameters. Required.
- default_key (Dict[str, Any]): Dictionary of default parameter values. Required.
Source code in src/ethopy/core/experiment.py
def setup_session_params(
    self, session_params: Dict[str, Any], default_key: Dict[str, Any]
) -> Dict[str, Any]:
    """Set up session parameters with validation.

    Args:
        session_params: Dictionary of session parameters
        default_key: Dictionary of default parameter values
    Returns:
        Dictionary of session parameters with defaults applied
    """
    # Convert dict to SessionParameters for validation and defaults
    params = SessionParameters.from_dict(session_params, default_key)
    params.validate()
    return params.to_dict()
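
A sketch of how missing values are filled in from default_key (values are illustrative):

    exp.default_key = {"min_reward": 30}
    params = exp.setup_session_params(
        {"max_reward": 3000, "hydrate_delay": 0, "user_name": "alice"},
        exp.default_key,
    )
    # min_reward comes from default_key; setup_conf_idx, start_time and
    # stop_time fall back to the SessionParameters defaults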

start()

Start the StateMachine.

Source code in src/ethopy/core/experiment.py
def start(self) -> None:
    """Start the StateMachine."""
    states = dict()
    for state in self.__class__.__subclasses__():  # Initialize states
        states.update({state().__class__.__name__: state(self)})
    state_control = StateMachine(states)
    self.interface.set_operation_status(True)
    state_control.run()

stop()

Stop the experiment.

Source code in src/ethopy/core/experiment.py
def stop(self) -> None:
    """Stop the epxeriment."""
    self.stim.exit()
    self.interface.release()
    self.beh.exit()
    if self.sync:
        while self.interface.is_recording():
            log.info("Waiting for recording to end...")
            time.sleep(1)
    self.logger.closeDatasets()
    self.in_operation = False

validate_condition(condition)

Validate a condition dictionary against the required fields.

Parameters:

- condition (Dict): The condition dictionary to validate. Required.

Raises:

- ValueError: If required fields are missing from the condition.

Source code in src/ethopy/core/experiment.py
def validate_condition(self, condition: Dict) -> None:
    """Validate a condition dictionary against the required fields.

    Args:
        condition (Dict): The condition dictionary to validate.

    Raises:
        ValueError: If required fields are missing from the condition.

    """
    missing_fields = set(self.required_fields) - set(condition)
    if missing_fields:
        raise ValueError(f"Missing experiment required fields: {missing_fields}")

SessionParameters dataclass

Internal class for managing and validating session-wide parameters.

This class handles all parameters that apply to the entire experimental session, as opposed to parameters that vary between trials/conditions.

Attributes:

- setup_conf_idx (int): Index for setup configuration (defaults to 0).
- user_name (str): Name of user running the experiment (defaults to "bot").
- start_time (str): Session start time in "HH:MM:SS" format (defaults to empty string).
- stop_time (str): Session stop time in "HH:MM:SS" format (defaults to empty string).
- max_reward (float): Maximum total reward allowed in session.
- min_reward (float): Minimum reward per trial.
- hydrate_delay (int): Delay between hydration rewards in ms.
- noresponse_intertrial (bool): Whether to have intertrial period on no response.
- bias_window (int): Window size for bias correction.

Source code in src/ethopy/core/experiment.py
@dataclass
class SessionParameters:
    """Internal class for managing and validating session-wide parameters.

    This class handles all parameters that apply to the entire experimental session,
    as opposed to parameters that vary between trials/conditions.

    Attributes:
        setup_conf_idx (int): Index for setup configuration (defaults to 0)
        user_name (str): Name of user running the experiment (defaults to "bot")
        start_time (str): Session start time in "HH:MM:SS" format (defaults to empty string)
        stop_time (str): Session stop time in "HH:MM:SS" format (defaults to empty string)
        max_reward (float): Maximum total reward allowed in session
        min_reward (float): Minimum reward per trial
        hydrate_delay (int): Delay between hydration rewards in ms
        noresponse_intertrial (bool): Whether to have intertrial period on no response
        bias_window (int): Window size for bias correction
    """
    max_reward: float
    min_reward: float
    hydrate_delay: int
    setup_conf_idx: int = 0  # Default value for setup configuration
    user_name: str = "bot"
    start_time: str = ""
    stop_time: str = ""

    @classmethod
    def from_dict(
        cls, params: Dict[str, Any], default_key: Dict[str, Any]
    ) -> "SessionParameters":
        """Create parameters from a dictionary, using defaults for missing values.

        Args:
            params: Dictionary of session parameters
            default_key: Dictionary of default parameter values
        Returns:
            SessionParameters instance with merged parameters
        """
        # Only use keys that exist in the dataclass
        valid_keys = set(cls.__annotations__.keys())
        invalid_keys = set(params.keys()) - valid_keys
        if invalid_keys:
            log.warning(f"Not used session parameters: {invalid_keys}")

        # Get valid parameters from both sources
        filtered_params = {}
        for key in valid_keys:
            if key in params:
                filtered_params[key] = params[key]
            elif key in default_key:
                filtered_params[key] = default_key[key]

        return cls(**filtered_params)

    def validate(self) -> None:
        """Validate parameters."""
        if self.start_time and not self.stop_time:
            raise ValueError(
                "If 'start_time' is defined, 'stop_time' must also be defined"
            )

        if self.start_time:
            try:
                datetime.strptime(self.start_time, "%H:%M:%S")
                datetime.strptime(self.stop_time, "%H:%M:%S")
            except ValueError:
                raise ValueError("Time must be in 'HH:MM:SS' format")

    def to_dict(self) -> Dict[str, Any]:
        """Convert parameters to dictionary format."""
        return {k: v for k, v in self.__dict__.items() if v is not None}
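
For orientation, here is a minimal sketch of how the pieces fit together: build the parameters from a task dictionary plus a defaults dictionary, validate them, and serialize them back to a plain dict. The values and the defaults dictionary are illustrative, and the import path is assumed from the source location shown above.

from ethopy.core.experiment import SessionParameters

# Hypothetical task parameters; unknown keys are logged as a warning and ignored.
params = {
    "max_reward": 3000.0,     # session-wide reward cap
    "min_reward": 30.0,       # smallest per-trial reward
    "start_time": "09:00:00",
    "stop_time": "17:00:00",
    "unused_key": 42,         # not a SessionParameters field
}

# Defaults supply any required field missing from params (here, hydrate_delay).
defaults = {"hydrate_delay": 30}

session = SessionParameters.from_dict(params, defaults)
session.validate()            # raises ValueError on malformed or inconsistent times
print(session.to_dict())      # plain dict of the non-None fields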

from_dict(params, default_key) classmethod

Create parameters from a dictionary, using defaults for missing values.

Parameters:

    params (Dict[str, Any]): Dictionary of session parameters. Required.
    default_key (Dict[str, Any]): Dictionary of default parameter values. Required.
Source code in src/ethopy/core/experiment.py
@classmethod
def from_dict(
    cls, params: Dict[str, Any], default_key: Dict[str, Any]
) -> "SessionParameters":
    """Create parameters from a dictionary, using defaults for missing values.

    Args:
        params: Dictionary of session parameters
        default_key: Dictionary of default parameter values
    Returns:
        SessionParameters instance with merged parameters
    """
    # Only use keys that exist in the dataclass
    valid_keys = set(cls.__annotations__.keys())
    invalid_keys = set(params.keys()) - valid_keys
    if invalid_keys:
        log.warning(f"Not used session parameters: {invalid_keys}")

    # Get valid parameters from both sources
    filtered_params = {}
    for key in valid_keys:
        if key in params:
            filtered_params[key] = params[key]
        elif key in default_key:
            filtered_params[key] = default_key[key]

    return cls(**filtered_params)

to_dict()

Convert parameters to dictionary format.

Source code in src/ethopy/core/experiment.py
def to_dict(self) -> Dict[str, Any]:
    """Convert parameters to dictionary format."""
    return {k: v for k, v in self.__dict__.items() if v is not None}

validate()

Validate parameters.

Source code in src/ethopy/core/experiment.py
def validate(self) -> None:
    """Validate parameters."""
    if self.start_time and not self.stop_time:
        raise ValueError(
            "If 'start_time' is defined, 'stop_time' must also be defined"
        )

    if self.start_time:
        try:
            datetime.strptime(self.start_time, "%H:%M:%S")
            datetime.strptime(self.stop_time, "%H:%M:%S")
        except ValueError:
            raise ValueError("Time must be in 'HH:MM:SS' format")

State

Base class for implementing experiment states.

This class provides the template for creating states in the experiment state machine. Each state should inherit from this class and implement the required methods.

Attributes:

    state_timer (Timer): Timer instance shared across all states.
    __shared_state (Dict[str, Any]): Dictionary containing shared state variables.

Source code in src/ethopy/core/experiment.py
class State:
    """Base class for implementing experiment states.

    This class provides the template for creating states in the experiment state
    machine. Each state should inherit from this class and implement the required
    methods.

    Attributes:
        state_timer: Timer instance shared across all states
        __shared_state: Dictionary containing shared state variables

    """

    state_timer: Timer = Timer()
    __shared_state: Dict[str, Any] = {}

    def __init__(self, parent: Optional["ExperimentClass"] = None) -> None:
        """Initialize state with optional parent experiment.

        Args:
            parent: Parent experiment instance this state belongs to

        """
        self.__dict__ = self.__shared_state
        if parent:
            self.__dict__.update(parent.__dict__)

    def entry(self) -> None:
        """Execute actions when entering this state."""

    def run(self) -> None:
        """Execute the main state logic."""

    def next(self) -> str:
        """Determine the next state to transition to.

        Returns:
            Name of the next state to transition to

        Raises:
            AssertionError: If next() is not implemented by child class

        """
        raise AssertionError("next not implemented")

    def exit(self) -> None:
        """Execute actions when exiting this state."""

__init__(parent=None)

Initialize state with optional parent experiment.

Parameters:

    parent (Optional[ExperimentClass]): Parent experiment instance this state belongs to. Defaults to None.
Source code in src/ethopy/core/experiment.py
def __init__(self, parent: Optional["ExperimentClass"] = None) -> None:
    """Initialize state with optional parent experiment.

    Args:
        parent: Parent experiment instance this state belongs to

    """
    self.__dict__ = self.__shared_state
    if parent:
        self.__dict__.update(parent.__dict__)

entry()

Execute actions when entering this state.

Source code in src/ethopy/core/experiment.py
def entry(self) -> None:
    """Execute actions when entering this state."""

exit()

Execute actions when exiting this state.

Source code in src/ethopy/core/experiment.py
def exit(self) -> None:
    """Execute actions when exiting this state."""

next()

Determine the next state to transition to.

Returns:

    str: Name of the next state to transition to.

Raises:

    AssertionError: If next() is not implemented by the child class.

Source code in src/ethopy/core/experiment.py
def next(self) -> str:
    """Determine the next state to transition to.

    Returns:
        Name of the next state to transition to

    Raises:
        AssertionError: If next() is not implemented by child class

    """
    raise AssertionError("next not implemented")

run()

Execute the main state logic.

Source code in src/ethopy/core/experiment.py
def run(self) -> None:
    """Execute the main state logic."""

StateMachine

State machine implementation for experiment control flow.

Manages transitions between experiment states and ensures proper execution of state entry/exit hooks. The state machine runs until it reaches the exit state.

Attributes:

    states (Dict[str, State]): Mapping of state names to state instances.
    futureState (State): Next state to transition to.
    currentState (State): Currently executing state.
    exitState (State): Final state that ends the state machine.

Source code in src/ethopy/core/experiment.py
class StateMachine:
    """State machine implementation for experiment control flow.

    Manages transitions between experiment states and ensures proper execution
    of state entry/exit hooks. The state machine runs until it reaches the exit
    state.

    Attributes:
        states (Dict[str, State]): Mapping of state names to state instances
        futureState (State): Next state to transition to
        currentState (State): Currently executing state
        exitState (State): Final state that ends the state machine

    """

    def __init__(self, states: Dict[str, State]) -> None:
        """Initialize the state machine.

        Args:
            states: Dictionary mapping state names to state instances

        Raises:
            ValueError: If required states (Entry, Exit) are missing

        """
        if "Entry" not in states or "Exit" not in states:
            raise ValueError("StateMachine requires Entry and Exit states")

        self.states = states
        self.futureState = states["Entry"]
        self.currentState = states["Entry"]
        self.exitState = states["Exit"]

    # # # # Main state loop # # # # #
    def run(self) -> None:
        """Execute the state machine until reaching exit state.

        The machine will:
        1. Check for state transition
        2. Execute exit hook of current state if transitioning
        3. Execute entry hook of new state if transitioning
        4. Execute the current state's main logic
        5. Determine next state

        Raises:
            KeyError: If a state requests transition to non-existent state
            RuntimeError: If a state's next() method raises an exception

        """
        try:
            while self.futureState != self.exitState:
                if self.currentState != self.futureState:
                    self.currentState.exit()
                    self.currentState = self.futureState
                    self.currentState.entry()

                self.currentState.run()

                next_state = self.currentState.next()
                if next_state not in self.states:
                    raise KeyError(f"Invalid state transition to: {next_state}")

                self.futureState = self.states[next_state]

            self.currentState.exit()
            self.exitState.run()

        except Exception as e:
            raise RuntimeError(
                f"""State machine error in state
                    {self.currentState.__class__.__name__}: {str(e)}"""
            ) from e
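
To see the control flow end to end, the sketch below wires the two required states plus one illustrative task state into a machine and runs it until the Exit state is reached. The Entry, Greet, and Exit classes here are hypothetical stand-ins, not library classes.

from ethopy.core.experiment import State, StateMachine


class Entry(State):
    def next(self) -> str:
        return "Greet"                    # hand off to the task state


class Greet(State):
    def run(self) -> None:
        print("one pass through the task state")

    def next(self) -> str:
        return "Exit"                     # request termination


class Exit(State):
    pass                                  # its run() is called once after the loop ends


machine = StateMachine({"Entry": Entry(), "Greet": Greet(), "Exit": Exit()})
machine.run()                             # loops until futureState is the Exit instance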

__init__(states)

Initialize the state machine.

Parameters:

    states (Dict[str, State]): Dictionary mapping state names to state instances. Required.

Raises:

    ValueError: If required states (Entry, Exit) are missing.

Source code in src/ethopy/core/experiment.py
def __init__(self, states: Dict[str, State]) -> None:
    """Initialize the state machine.

    Args:
        states: Dictionary mapping state names to state instances

    Raises:
        ValueError: If required states (Entry, Exit) are missing

    """
    if "Entry" not in states or "Exit" not in states:
        raise ValueError("StateMachine requires Entry and Exit states")

    self.states = states
    self.futureState = states["Entry"]
    self.currentState = states["Entry"]
    self.exitState = states["Exit"]

run()

Execute the state machine until reaching exit state.

The machine will:

1. Check for a state transition
2. Execute the exit hook of the current state if transitioning
3. Execute the entry hook of the new state if transitioning
4. Execute the current state's main logic
5. Determine the next state

Raises:

    KeyError: If a state requests a transition to a non-existent state.
    RuntimeError: If a state's next() method raises an exception.

Source code in src/ethopy/core/experiment.py
def run(self) -> None:
    """Execute the state machine until reaching exit state.

    The machine will:
    1. Check for state transition
    2. Execute exit hook of current state if transitioning
    3. Execute entry hook of new state if transitioning
    4. Execute the current state's main logic
    5. Determine next state

    Raises:
        KeyError: If a state requests transition to non-existent state
        RuntimeError: If a state's next() method raises an exception

    """
    try:
        while self.futureState != self.exitState:
            if self.currentState != self.futureState:
                self.currentState.exit()
                self.currentState = self.futureState
                self.currentState.entry()

            self.currentState.run()

            next_state = self.currentState.next()
            if next_state not in self.states:
                raise KeyError(f"Invalid state transition to: {next_state}")

            self.futureState = self.states[next_state]

        self.currentState.exit()
        self.exitState.run()

    except Exception as e:
        raise RuntimeError(
            f"""State machine error in state
                {self.currentState.__class__.__name__}: {str(e)}"""
        ) from e
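
Because run() wraps any failure in a RuntimeError that names the current state, callers only need a single except clause. A short illustration with a hypothetical misbehaving state:

from ethopy.core.experiment import State, StateMachine


class BadState(State):
    def next(self) -> str:
        return "DoesNotExist"             # not in the states dict -> KeyError inside run()


try:
    StateMachine({"Entry": BadState(), "Exit": State()}).run()
except RuntimeError as err:
    print(err)                            # message names the failing state (BadState)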