public abstract class ParameterEstimator extends TypedAtomicActor
This actor implements the Expectation-Maximization (EM) algorithm for parameter estimation in graphical stochastic models. Two fundamental types of Bayesian network models are supported: the Mixture Model (MM) and the Hidden Markov Model (HMM). The input is an array of observations of arbitrary length, and the outputs are the parameter estimates for the chosen model.
The output ports carry the parameter estimates of a Gaussian MM or HMM. The Mixture Model is parameterized by M states, each distributed according to the distribution specified by the emissionDistribution parameter; currently, the actor supports Gaussian emissions. The mean output is a double array containing the mean estimates, and the sigma output is a double array containing the standard deviation estimates of each mixture component. If the modelType is HMM, an additional output, transitionMatrix, provides an estimate of the transition matrix governing the Markov process that represents the hidden-state evolution. If the modelType is MM, this port instead outputs a double array with the prior probability estimates of the mixture components.
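The following Java sketch makes the Gaussian MM outputs concrete by performing one EM iteration for a one-dimensional Gaussian mixture. The E-step computes each component's responsibility for each observation. The M-step then re-estimates the quantities reported on the mean, sigma, and prior outputs. The class and method names (GaussianMixtureEmStep, step) are hypothetical, and the sketch assumes scalar observations. It is not the actor's actual implementation, and it needs Java 16 or later for the record type.

```java
/** Hypothetical helper (not Ptolemy II API): one EM iteration for a 1-D Gaussian mixture. */
final class GaussianMixtureEmStep {

    /** New estimates plus the log-likelihood of x under the parameters passed in. */
    record Result(double[] mean, double[] sigma, double[] prior, double logLikelihood) {}

    static Result step(double[] x, double[] mean, double[] sigma, double[] prior) {
        int n = x.length;
        int m = mean.length;
        double[][] resp = new double[n][m];
        double logLik = 0.0;

        // E-step: resp[t][k] = P(component k | x[t]) under the current parameters.
        for (int t = 0; t < n; t++) {
            double total = 0.0;
            for (int k = 0; k < m; k++) {
                resp[t][k] = prior[k] * gaussian(x[t], mean[k], sigma[k]);
                total += resp[t][k];
            }
            for (int k = 0; k < m; k++) {
                resp[t][k] /= total;
            }
            logLik += Math.log(total); // log p(x[t]), summed over all observations
        }

        // M-step: responsibility-weighted priors, means, and standard deviations.
        double[] newMean = new double[m];
        double[] newSigma = new double[m];
        double[] newPrior = new double[m];
        for (int k = 0; k < m; k++) {
            double nk = 0.0;
            double weightedSum = 0.0;
            for (int t = 0; t < n; t++) {
                nk += resp[t][k];
                weightedSum += resp[t][k] * x[t];
            }
            newMean[k] = weightedSum / nk;
            double weightedSq = 0.0;
            for (int t = 0; t < n; t++) {
                double d = x[t] - newMean[k];
                weightedSq += resp[t][k] * d * d;
            }
            newSigma[k] = Math.sqrt(weightedSq / nk);
            newPrior[k] = nk / n;
        }
        return new Result(newMean, newSigma, newPrior, logLik);
    }

    /** Normal density N(x; mu, sigma^2). */
    static double gaussian(double x, double mu, double sigma) {
        double z = (x - mu) / sigma;
        return Math.exp(-0.5 * z * z) / (sigma * Math.sqrt(2.0 * Math.PI));
    }
}
```

A driver would call step repeatedly, feeding each result's mean, sigma, and prior arrays back in. By the standard EM argument, the reported log-likelihood does not decrease from one call to the next, which is what makes a likelihood-based stopping rule meaningful.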
The maxIterations parameter sets the maximum number of EM iterations. Iteration stops earlier once the log-likelihood, log P(observations | model parameters), stays within likelihoodThreshold of the previous likelihood estimate. The default likelihood threshold is 1E-4; for precise applications, it may be set to a smaller positive value. If, at any point, the estimates become NaN, the user is notified that the algorithm did not converge and is given the option to randomize the initial guesses and iterate again.
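The stopping rule in the previous paragraph can be sketched as a small Java loop. Here emIteration is a hypothetical callback that performs one E-step and M-step and returns the new log-likelihood. The maxIterations and likelihoodThreshold arguments play the roles of the actor's parameters of the same names. As an assumption, the sketch treats a NaN log-likelihood as a stand-in for NaN estimates, and it leaves the restart with randomized guesses to the caller. The names EmStoppingRule and Outcome are illustrative only.

```java
import java.util.function.DoubleSupplier;

/** Hypothetical sketch of the EM stopping rule; not the Ptolemy II implementation. */
final class EmStoppingRule {

    /** Whether the run converged, how many iterations ran, and the last log-likelihood. */
    record Outcome(boolean converged, int iterations, double logLikelihood) {}

    /**
     * Calls emIteration (one E-step plus M-step, returning the new log-likelihood) until two
     * successive log-likelihoods differ by less than likelihoodThreshold, or until
     * maxIterations calls have been made.
     */
    static Outcome run(DoubleSupplier emIteration, int maxIterations, double likelihoodThreshold) {
        double previous = Double.NEGATIVE_INFINITY;
        for (int i = 1; i <= maxIterations; i++) {
            double current = emIteration.getAsDouble();
            if (Double.isNaN(current)) {
                // Diverged: the caller can notify the user and retry from randomized guesses.
                return new Outcome(false, i, current);
            }
            if (Math.abs(current - previous) < likelihoodThreshold) {
                return new Outcome(true, i, current);
            }
            previous = current;
        }
        return new Outcome(false, maxIterations, previous);
    }
}
```

For example, consider a callback that returns -10 + 2^-k on its k-th call. With the default threshold of 1E-4, it converges at k = 14, because 2^-14 is the first successive difference below the threshold.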
See Also: ObservationClassifier, HMMGaussianEstimator
Nested classes/interfaces inherited from class Entity: Entity.ContainedObjectsIterator
| Modifier and Type | Field | Description |
|---|---|---|
| protected double[][] | _A0 | User-defined initial guess array for the state transition matrix. |
| protected double | _likelihood | The likelihood of the observations given the current estimates, $L(x_1, \ldots, x_T \mid \theta_p)$. |
| protected double | _likelihoodThreshold | The likelihood threshold. |
| protected int | _nIterations | User-defined number of iterations of the alpha-beta recursion. |
| protected int | _nStates | Number of hidden states in the model. |
| protected int | _obsDimension | Observation dimension. |
| protected double[][] | _observations | Observation array. |
| protected double[] | _priorIn | The prior estimates used in the EM iterations. |
| protected double[] | _priors | Prior distribution on hidden states. |
| protected boolean | _randomize | Whether to randomize the initial guess vectors. |
| protected double[][] | _transitionMatrix | Initial guess array for the state transition matrix for the Alpha-Beta Recursion. |
| Parameter | A0 | The user-provided initial guess of the transition probability matrix. |
| TypedIOPort | input | The input port that provides the sample observations. |
| protected double | likelihood | Fitted model likelihood. |
| TypedIOPort | likelihoodOut | An output port of type Double that contains the likelihood. |
| Parameter | likelihoodThreshold | The user-provided threshold on the minimum desired improvement in likelihood per iteration. |
| Parameter | maxIterations | The user-provided maximum number of allowed iterations of the Alpha-Beta Recursion. |
| protected java.util.HashMap | newEstimates | Updated parameter sets, used during Expectation-Maximization. |
| Parameter | nStates | Number of states of the HMM. |
| Parameter | priorDistribution | The user-provided initial guess on the prior probability distribution. |
| TypedIOPort | priorEstimates | The vector estimate for the prior distribution on the set of states. |
| Parameter | randomizeGuessVectors | Boolean that determines whether to randomize the initial guess vectors. |
| TypedIOPort | transitionMatrix | The transition matrix estimate obtained by iterating over the observation set. |
Fields inherited from class TypedAtomicActor: _typesValid
Fields inherited from class AtomicActor: _actorFiringListeners, _initializables, _notifyingActorFiring, _stopRequested
Fields inherited from class NamedObj: _changeListeners, _changeLock, _changeRequests, _debugging, _debugListeners, _deferChangeRequests, _elementName, _isPersistent, _verbose, _workspace, ATTRIBUTES, CLASSNAME, COMPLETE, CONTENTS, DEEP, FULLNAME, LINKS
Fields inherited from interface Executable: COMPLETED, NOT_READY, STOP_ITERATING
| Constructor | Description |
|---|---|
| ParameterEstimator(CompositeEntity container, java.lang.String name) | Construct an actor with the given container and name. |
| Modifier and Type | Method | Description |
|---|---|---|
| protected abstract boolean | _checkForConvergence(int i) | Check whether the EM iteration has converged. |
| protected boolean | _EMParameterEstimation() | Run the Expectation-Maximization algorithm, which iteratively refines the parameter estimates. |
| protected void | _initializeArrays() | Initialize arrays to be used in parameter estimation. |
| protected abstract void | _initializeEMParameters() | Initialize the parameters used in Expectation-Maximization. |
| protected abstract void | _iterateEM() | Perform one EM iteration. |
| protected abstract void | _updateEstimates() | Update the parameter estimates. |
| void | attributeChanged(Attribute attribute) | React to a change in an attribute. |
| java.lang.Object | clone(Workspace workspace) | Clone the actor into the specified workspace. |
| protected abstract double | emissionProbability(double[] y, int hiddenState) | Compute the emission probability of observation y given the hidden state. |
| void | fire() | Do nothing. |
| protected java.util.HashMap | HMMAlphaBetaRecursion(double[][] y, double[][] A, double[] prior, int nCategories) | Java implementation of the Baum-Welch algorithm for parameter estimation and cluster assignment. The exact algorithm used here is known as the Alpha-Gamma Recursion, a slightly more convenient version of the well-known Alpha-Beta recursion. |
Methods inherited from class TypedAtomicActor: _containedTypeConstraints, _customTypeConstraints, _defaultTypeConstraints, _fireAt, _fireAt, attributeTypeChanged, clone, isBackwardTypeInferenceEnabled, newPort, typeConstraintList, typeConstraints
Methods inherited from class AtomicActor: _actorFiring, _actorFiring, _declareDelayDependency, addActorFiringListener, addInitializable, connectionsChanged, createReceivers, declareDelayDependency, getCausalityInterface, getDirector, getExecutiveDirector, getManager, initialize, inputPortList, isFireFunctional, isStrict, iterate, newReceiver, outputPortList, postfire, prefire, preinitialize, pruneDependencies, recordFiring, removeActorFiringListener, removeDependency, removeInitializable, setContainer, stop, stopFire, terminate, wrapup
Methods inherited from class ComponentEntity: _adjustDeferrals, _checkContainer, _getContainedObject, _propagateExistence, getContainer, instantiate, isAtomic, isOpaque, moveDown, moveToFirst, moveToIndex, moveToLast, moveUp, propagateExistence, setName
Methods inherited from class Entity: _addPort, _description, _exportMoMLContents, _removePort, _validateSettables, connectedPortList, connectedPorts, containedObjectsIterator, getAttribute, getPort, getPorts, linkedRelationList, linkedRelations, portList, removeAllPorts, setClassDefinition, uniqueName
Methods inherited from class InstantiableNamedObj: _setParent, exportMoML, getChildren, getElementName, getParent, getPrototypeList, isClassDefinition, isWithinClassDefinition
Methods inherited from class NamedObj: _addAttribute, _adjustOverride, _attachText, _cloneFixAttributeFields, _containedDecorators, _copyChangeRequestList, _debug, _debug, _debug, _debug, _debug, _executeChangeRequests, _getIndentPrefix, _isMoMLSuppressed, _markContentsDerived, _notifyHierarchyListenersAfterChange, _notifyHierarchyListenersBeforeChange, _propagateValue, _removeAttribute, _splitName, _stripNumericSuffix, addChangeListener, addDebugListener, addHierarchyListener, attributeDeleted, attributeList, attributeList, decorators, deepContains, depthInHierarchy, description, description, event, executeChangeRequests, exportMoML, exportMoML, exportMoML, exportMoML, exportMoMLPlain, getAttribute, getAttributes, getChangeListeners, getClassName, getDecoratorAttribute, getDecoratorAttributes, getDerivedLevel, getDerivedList, getDisplayName, getFullName, getModelErrorHandler, getName, getName, getSource, handleModelError, isDeferringChangeRequests, isOverridden, isPersistent, lazyContainedObjectsIterator, message, notifyOfNameChange, propagateValue, propagateValues, removeAttribute, removeChangeListener, removeDebugListener, removeHierarchyListener, requestChange, setClassName, setDeferringChangeRequests, setDerivedLevel, setDisplayName, setModelErrorHandler, setPersistent, setSource, sortContainedObjects, toplevel, toString, validateSettables, workspace
Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Actor: createReceivers, getCausalityInterface, getDirector, getExecutiveDirector, getManager, inputPortList, newReceiver, outputPortList
Methods inherited from interface Executable: isFireFunctional, isStrict, iterate, postfire, prefire, stop, stopFire, terminate
Methods inherited from interface Initializable: addInitializable, initialize, preinitialize, removeInitializable, wrapup
Methods inherited from interface Nameable: description, getContainer, getDisplayName, getFullName, getName, getName, setName
Methods inherited from interface Derivable: getDerivedLevel, getDerivedList, propagateValue
public Parameter A0
public Parameter likelihoodThreshold
public Parameter maxIterations
public Parameter nStates
public Parameter randomizeGuessVectors
public Parameter priorDistribution
public TypedIOPort input
public TypedIOPort likelihoodOut
public TypedIOPort priorEstimates
public TypedIOPort transitionMatrix
protected double[][] _A0
protected double _likelihood
protected double _likelihoodThreshold
protected int _nIterations
protected int _obsDimension
protected int _nStates
protected double[][] _observations
protected double[] _priors
protected double[] _priorIn
protected boolean _randomize
protected double[][] _transitionMatrix
protected java.util.HashMap newEstimates
protected double likelihood
public ParameterEstimator(CompositeEntity container, java.lang.String name) throws NameDuplicationException, IllegalActionException
Parameters:
container - The container.
name - The name of this actor.
Throws:
IllegalActionException - If the actor cannot be contained by the proposed container.
NameDuplicationException - If the container already has an actor with this name.
public void attributeChanged(Attribute attribute) throws IllegalActionException
Overrides: attributeChanged in class NamedObj
Parameters:
attribute - The attribute that changed.
Throws:
IllegalActionException - If the change is not acceptable to this container (not thrown in this base class).
public java.lang.Object clone(Workspace workspace) throws java.lang.CloneNotSupportedException
Overrides: clone in class TypedAtomicActor
Parameters:
workspace - The workspace for the new object.
Throws:
java.lang.CloneNotSupportedException - If a derived class contains an attribute that cannot be cloned.
See Also: NamedObj.exportMoML(Writer, int, String), NamedObj.setDeferringChangeRequests(boolean)
public void fire() throws IllegalActionException
Specified by: fire in interface Executable
Overrides: fire in class AtomicActor<TypedIOPort>
Throws:
IllegalActionException - Not thrown in this base class.
protected boolean _EMParameterEstimation() throws IllegalActionException
Throws:
IllegalActionException
protected abstract double emissionProbability(double[] y, int hiddenState) throws IllegalActionException
Parameters:
y - Input observation.
hiddenState - Index of the hidden state.
Throws:
IllegalActionException
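The actor currently supports Gaussian emissions, and for those a concrete emissionProbability reduces to evaluating a normal density. The sketch below is a hypothetical stand-alone class, not the HMMGaussianEstimator source. It assumes one-dimensional observations (only y[0] is used) and per-state mean and standard deviation arrays. It also omits the IllegalActionException declared by the real abstract method. For continuous emissions, the value is a density, so it can exceed 1.

```java
/** Hypothetical stand-alone Gaussian emission model (not the HMMGaussianEstimator source). */
final class GaussianEmission {
    private final double[] mean;   // per-state mean
    private final double[] sigma;  // per-state standard deviation, assumed > 0

    GaussianEmission(double[] mean, double[] sigma) {
        if (mean.length != sigma.length) {
            throw new IllegalArgumentException("mean and sigma must have one entry per state");
        }
        this.mean = mean.clone();
        this.sigma = sigma.clone();
    }

    /** Returns the density N(y[0]; mean[hiddenState], sigma[hiddenState]^2); assumes 1-D observations. */
    double emissionProbability(double[] y, int hiddenState) {
        double s = sigma[hiddenState];
        double z = (y[0] - mean[hiddenState]) / s;
        return Math.exp(-0.5 * z * z) / (s * Math.sqrt(2.0 * Math.PI));
    }
}
```

For example, with mean {0.0, 3.0} and sigma {1.0, 2.0}, emissionProbability(new double[]{0.0}, 0) evaluates to 1/sqrt(2π), about 0.3989.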
protected void _initializeArrays() throws IllegalActionException
Throws:
IllegalActionException - Not thrown in this base class.
protected abstract void _initializeEMParameters()
protected abstract void _iterateEM() throws IllegalActionException
Throws:
IllegalActionException - If there is a problem.
protected abstract boolean _checkForConvergence(int i) throws IllegalActionException
Parameters:
i - Current iteration index.
Throws:
IllegalActionException - If there is a problem.
protected abstract void _updateEstimates()
protected java.util.HashMap HMMAlphaBetaRecursion(double[][] y, double[][] A, double[] prior, int nCategories) throws IllegalActionException
Parameters:
y - Input observation stream.
A - Transition probability matrix guess.
prior - Prior state distribution guess.
nCategories - Number of categories in the multinomial distribution, where applicable.
Throws:
IllegalActionException
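The method summary describes HMMAlphaBetaRecursion as a Baum-Welch implementation. The Java sketch below illustrates that idea. It implements the standard scaled forward-backward (alpha-beta) recursion, not the alpha-gamma variant the actor uses, and then applies the Baum-Welch re-estimation of the transition matrix. It assumes the emission likelihoods have already been evaluated into b[t][k], for example by an emissionProbability method. The class and record names are hypothetical. The log-likelihood is accumulated as the sum of the logarithms of the per-step scale factors, which avoids the underflow caused by multiplying many small probabilities.

```java
import java.util.Arrays;

/** Hypothetical sketch: scaled forward-backward recursion plus Baum-Welch transition update. */
final class ForwardBackward {

    /** gamma[t][k] = P(state k at t | all observations); newA = re-estimated transitions. */
    record Result(double[][] gamma, double[][] newA, double logLikelihood) {}

    /**
     * @param b     b[t][k] = likelihood of observation t under hidden state k
     * @param a     a[i][j] = P(state j at t + 1 | state i at t)
     * @param prior prior[k] = P(state k at t = 0)
     */
    static Result run(double[][] b, double[][] a, double[] prior) {
        int T = b.length;
        int K = prior.length;
        double[][] alpha = new double[T][K];
        double[][] beta = new double[T][K];
        double[] scale = new double[T];

        // Forward (alpha) pass, normalized at each step; scale[t] holds the normalizer.
        for (int k = 0; k < K; k++) {
            alpha[0][k] = prior[k] * b[0][k];
        }
        scale[0] = normalize(alpha[0]);
        for (int t = 1; t < T; t++) {
            for (int j = 0; j < K; j++) {
                double s = 0.0;
                for (int i = 0; i < K; i++) {
                    s += alpha[t - 1][i] * a[i][j];
                }
                alpha[t][j] = s * b[t][j];
            }
            scale[t] = normalize(alpha[t]);
        }

        // Backward (beta) pass, divided by the same scale factors.
        Arrays.fill(beta[T - 1], 1.0);
        for (int t = T - 2; t >= 0; t--) {
            for (int i = 0; i < K; i++) {
                double s = 0.0;
                for (int j = 0; j < K; j++) {
                    s += a[i][j] * b[t + 1][j] * beta[t + 1][j];
                }
                beta[t][i] = s / scale[t + 1];
            }
        }

        // Posteriors gamma, expected transition counts (summed xi), and log P(observations).
        double[][] gamma = new double[T][K];
        double[][] xiSum = new double[K][K];
        double[] fromSum = new double[K];
        double logLik = 0.0;
        for (int t = 0; t < T; t++) {
            for (int k = 0; k < K; k++) {
                gamma[t][k] = alpha[t][k] * beta[t][k];
            }
            logLik += Math.log(scale[t]);
            if (t < T - 1) {
                for (int i = 0; i < K; i++) {
                    fromSum[i] += gamma[t][i];
                    for (int j = 0; j < K; j++) {
                        xiSum[i][j] += alpha[t][i] * a[i][j] * b[t + 1][j] * beta[t + 1][j] / scale[t + 1];
                    }
                }
            }
        }

        // Baum-Welch M-step for the transition matrix; unvisited rows keep their old values.
        double[][] newA = new double[K][K];
        for (int i = 0; i < K; i++) {
            for (int j = 0; j < K; j++) {
                newA[i][j] = fromSum[i] > 0.0 ? xiSum[i][j] / fromSum[i] : a[i][j];
            }
        }
        return new Result(gamma, newA, logLik);
    }

    /** Divides v by its sum in place and returns that sum. */
    private static double normalize(double[] v) {
        double sum = 0.0;
        for (double x : v) {
            sum += x;
        }
        for (int i = 0; i < v.length; i++) {
            v[i] /= sum;
        }
        return sum;
    }
}
```

The scaled beta values are divided by the same factors as alpha, so their product gives gamma directly without a separate normalization.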