Circuit Breaker

// CircuitBreaker is a middleware that wraps the handler in a circuit breaker.
// Based on the configuration, the circuit breaker will fail fast if the handler keeps returning errors.
// This is useful for preventing cascading failures.
type CircuitBreaker struct {
	cb *gobreaker.CircuitBreaker
}
// NewCircuitBreaker returns a new CircuitBreaker middleware.
// Refer to the gobreaker documentation for the available settings.
func NewCircuitBreaker(settings gobreaker.Settings) CircuitBreaker {
	return CircuitBreaker{
		cb: gobreaker.NewCircuitBreaker(settings),
	}
}
// Middleware returns the CircuitBreaker middleware.
func (c CircuitBreaker) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		out, err := c.cb.Execute(func() (interface{}, error) {
			return h(msg)
		})

		var result []*message.Message
		if out != nil {
			result = out.([]*message.Message)
		}

		return result, err
	}
}
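
A minimal sketch of wiring the circuit breaker into a router; the gobreaker settings are illustrative, and router is assumed to be an existing *message.Router created elsewhere.

// Illustrative configuration; tune gobreaker.Settings for your workload.
cb := middleware.NewCircuitBreaker(gobreaker.Settings{
	Name:        "example-handler", // hypothetical breaker name
	MaxRequests: 1,                 // requests allowed while half-open
	Timeout:     30 * time.Second,  // how long the breaker stays open
})
router.AddMiddleware(cb.Middleware)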

Correlation

// SetCorrelationID sets a correlation ID for the message.
//
// SetCorrelationID should be called when the message enters the system.
// When a message is produced within a request (for example, an HTTP request),
// the message's correlation ID should be the same as the request's correlation ID.
func SetCorrelationID(id string, msg *message.Message) {
	if MessageCorrelationID(msg) != "" {
		return
	}

	msg.Metadata.Set(CorrelationIDMetadataKey, id)
}
// MessageCorrelationID returns correlation ID from the message.
func MessageCorrelationID(message *message.Message) string {
	return message.Metadata.Get(CorrelationIDMetadataKey)
}
// CorrelationID adds the correlation ID to all messages produced by the handler.
// The ID is taken from the message received by the handler.
//
// For CorrelationID to work correctly, SetCorrelationID must be called on the first message entering the system.
func CorrelationID(h message.HandlerFunc) message.HandlerFunc {
	return func(message *message.Message) ([]*message.Message, error) {
		producedMessages, err := h(message)

		correlationID := MessageCorrelationID(message)
		for _, msg := range producedMessages {
			SetCorrelationID(correlationID, msg)
		}

		return producedMessages, err
	}
}
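
A short sketch of how the two pieces fit together; the correlation ID source and the router are assumptions, not part of this package.

// Illustrative: in practice the ID usually comes from the incoming request,
// for example an HTTP header.
msg := message.NewMessage(watermill.NewUUID(), message.Payload(`{"foo":"bar"}`))
middleware.SetCorrelationID("id-from-incoming-request", msg) // hypothetical ID

// Every message produced by handlers will then carry the same correlation ID.
router.AddMiddleware(middleware.CorrelationID)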

Deduplicator

// Deduplicator drops similar messages if they are present
// in an [ExpiringKeyRepository]. The similarity is determined
// by a [MessageHasher]. A timeout is applied to repository
// operations using [context.WithTimeout].
//
// Call [Deduplicator.Middleware] for a new middleware
// or [Deduplicator.PublisherDecorator] for a [message.PublisherDecorator].
//
// KeyFactory defaults to [NewMessageHasherAdler32] with the read
// limit set to [math.MaxInt64] for fast tagging.
// Use [NewMessageHasherSHA256] for minimal collisions.
//
// Repository defaults to [NewMapExpiringKeyRepository] with a one
// minute retention window. This default setting is performant
// but **does not support distributed operations**. If you
// implement an [ExpiringKeyRepository] backed by Redis,
// please submit a pull request.
//
// Timeout defaults to one minute. If lower than
// five milliseconds, it is set to five milliseconds.
//
// [ExpiringKeyRepository] must expire values
// within a certain time window. If there is no expiration, only one
// unique message will ever be delivered as long as the repository
// keeps its state.
type Deduplicator struct {
	KeyFactory MessageHasher
	Repository ExpiringKeyRepository
	Timeout    time.Duration
}
// IsDuplicate returns true if the message hash tag calculated
// using a [MessageHasher] was seen within the deduplication time window.
func (d *Deduplicator) IsDuplicate(m *message.Message) (bool, error) {
	key, err := d.KeyFactory(m)
	if err != nil {
		return false, err
	}
	ctx, cancel := context.WithTimeout(m.Context(), d.Timeout)
	defer cancel()
	return d.Repository.IsDuplicate(ctx, key)
}
// Middleware returns the [message.HandlerMiddleware]
// that drops similar messages in a given time window.
func (d *Deduplicator) Middleware(h message.HandlerFunc) message.HandlerFunc {
	d = applyDefaultsToDeduplicator(d)
	return func(msg *message.Message) ([]*message.Message, error) {
		isDuplicate, err := d.IsDuplicate(msg)
		if err != nil {
			return nil, err
		}
		if isDuplicate {
			return nil, nil
		}
		return h(msg)
	}
}
// NewMapExpiringKeyRepository returns a memory store
// backed by a regular hash map protected by
// a [sync.Mutex]. The state **cannot be shared or synchronized
// between instances** by design for performance.
//
// If you need to drop duplicate messages by orchestration,
// implement [ExpiringKeyRepository] interface backed by Redis
// or similar.
//
// Window specifies the minimum duration for which duplicate
// tags are remembered. The real duration can extend up to 50%
// longer because it depends on the clean-up cycle.
func NewMapExpiringKeyRepository(window time.Duration) (ExpiringKeyRepository, error) {
	if window < time.Millisecond {
		return nil, errors.New("deduplication window of less than a millisecond is impractical")
	}

	kr := &mapExpiringKeyRepository{
		window: window,
		mu:     &sync.Mutex{},
		tags:   make(map[string]time.Time),
	}
	go kr.cleanOutLoop(context.Background(), time.NewTicker(window/2))
	return kr, nil
}
// Len returns the number of known tags that have not been
// cleaned out yet.
func (kr *mapExpiringKeyRepository) Len() (count int) {
	kr.mu.Lock()
	count = len(kr.tags)
	kr.mu.Unlock()
	return
}
// NewMessageHasherAdler32 generates message hashes using a fast
// Adler-32 checksum of the [message.Message] body. Read
// limit specifies how many bytes of the message are
// used for calculating the hash.
//
// Lower limit improves performance but results in more false
// positives. Read limit must be greater than
// [MessageHasherReadLimitMinimum].
func NewMessageHasherAdler32(readLimit int64) MessageHasher {
	if readLimit < MessageHasherReadLimitMinimum {
		readLimit = MessageHasherReadLimitMinimum
	}
	return func(m *message.Message) (string, error) {
		h := adler32.New()
		_, err := io.CopyN(h, bytes.NewReader(m.Payload), readLimit)
		if err != nil && err != io.EOF {
			return "", err
		}
		return string(h.Sum(nil)), nil
	}
}
// NewMessageHasherSHA256 generates message hashes using a slower
// but more resilient hashing of the [message.Message] body. Read
// limit specifies how many bytes of the message are
// used for calculating the hash.
//
// Lower limit improves performance but results in more false
// positives. Read limit must be greater than
// [MessageHasherReadLimitMinimum].
func NewMessageHasherSHA256(readLimit int64) MessageHasher {
	if readLimit < MessageHasherReadLimitMinimum {
		readLimit = MessageHasherReadLimitMinimum
	}

	return func(m *message.Message) (string, error) {
		h := sha256.New()
		_, err := io.CopyN(h, bytes.NewReader(m.Payload), readLimit)
		if err != nil && err != io.EOF {
			return "", err
		}
		return string(h.Sum(nil)), nil
	}
}
// NewMessageHasherFromMetadataField looks for a hash value
// inside message metadata instead of calculating a new one.
// Useful if a [MessageHasher] was applied in a previous
// [message.HandlerFunc].
func NewMessageHasherFromMetadataField(field string) MessageHasher {
	return func(m *message.Message) (string, error) {
		fromMetadata, ok := m.Metadata[field]
		if ok {
			return fromMetadata, nil
		}
		return "", fmt.Errorf("cannot recover hash value from metadata of message #%s: field %q is absent", m.UUID, field)
	}
}
// PublisherDecorator returns a decorator that
// acknowledges and drops every [message.Message] that
// was recognized by a [Deduplicator].
//
// The returned decorator provides the same functionality
// to a [message.Publisher] as [Deduplicator.Middleware]
// to a [message.Router].
func (d *Deduplicator) PublisherDecorator() message.PublisherDecorator {
	return func(pub message.Publisher) (message.Publisher, error) {
		if pub == nil {
			return nil, errors.New("cannot decorate a <nil> publisher")
		}

		return &deduplicatingPublisherDecorator{
			Publisher:    pub,
			deduplicator: applyDefaultsToDeduplicator(d),
		}, nil
	}
}
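
A sketch of configuring a Deduplicator explicitly instead of relying on the defaults; the retention window, read limit, and timeout are arbitrary example values.

repo, err := middleware.NewMapExpiringKeyRepository(5 * time.Minute) // example retention window
if err != nil {
	panic(err)
}
deduplicator := &middleware.Deduplicator{
	KeyFactory: middleware.NewMessageHasherSHA256(1024), // hash at most 1 KiB of the payload
	Repository: repo,
	Timeout:    time.Second,
}
router.AddMiddleware(deduplicator.Middleware)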

Delay On Error

// DelayOnError is a middleware that adds the delay metadata to the message if an error occurs.
//
// IMPORTANT: The delay metadata doesn't cause delays with all Pub/Subs! Using it won't have any effect on Pub/Subs that don't support it.
// See the list of supported Pub/Subs in the documentation: https://watermill.io/advanced/delayed-messages/
type DelayOnError struct {
	// InitialInterval is the first interval between retries. Subsequent intervals will be scaled by Multiplier.
	InitialInterval time.Duration
	// MaxInterval sets the limit for the exponential backoff of retries. The interval will not be increased beyond MaxInterval.
	MaxInterval time.Duration
	// Multiplier is the factor by which the waiting interval will be multiplied between retries.
	Multiplier float64
}
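
A sketch of a typical configuration, assuming DelayOnError exposes a Middleware method like the other middlewares in this package; the intervals are example values.

delayOnError := middleware.DelayOnError{
	InitialInterval: 10 * time.Second, // example values only
	MaxInterval:     5 * time.Minute,
	Multiplier:      2,
}
router.AddMiddleware(delayOnError.Middleware) // assumes a Middleware method, as with Retry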

Duplicator

// Duplicator processes each message twice, to ensure that the endpoint is idempotent.
func Duplicator(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		firstProducedMessages, firstErr := h(msg)
		if firstErr != nil {
			return nil, firstErr
		}

		secondProducedMessages, secondErr := h(msg)
		if secondErr != nil {
			return nil, secondErr
		}

		return append(firstProducedMessages, secondProducedMessages...), nil
	}
}

Ignore Errors

// IgnoreErrors provides a middleware that makes the handler ignore some explicitly whitelisted errors.
type IgnoreErrors struct {
	ignoredErrors map[string]struct{}
}
// NewIgnoreErrors creates a new IgnoreErrors middleware.
func NewIgnoreErrors(errs []error) IgnoreErrors {
	errsMap := make(map[string]struct{}, len(errs))

	for _, err := range errs {
		errsMap[err.Error()] = struct{}{}
	}

	return IgnoreErrors{errsMap}
}
// Middleware returns the IgnoreErrors middleware.
func (i IgnoreErrors) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		events, err := h(msg)
		if err != nil {
			if _, ok := i.ignoredErrors[errors.Cause(err).Error()]; ok {
				return events, nil
			}

			return events, err
		}

		return events, nil
	}
}
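
A sketch of ignoring a specific sentinel error; ErrNotFound is a hypothetical error defined by the application. Note that matching compares the string of errors.Cause(err), so wrapped errors are matched by their root cause's message.

var ErrNotFound = errors.New("entity not found") // hypothetical application error

ignore := middleware.NewIgnoreErrors([]error{ErrNotFound})
router.AddMiddleware(ignore.Middleware)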

Instant Ack

// InstantAck makes the handler instantly acknowledge the incoming message, regardless of any errors.
// It may be used to gain throughput, but at a cost:
// If you had exactly-once delivery, you may expect at-least-once instead.
// If you had ordered messages, the ordering might be broken.
func InstantAck(h message.HandlerFunc) message.HandlerFunc {
	return func(message *message.Message) ([]*message.Message, error) {
		message.Ack()
		return h(message)
	}
}

Poison

// PoisonQueue provides a middleware that salvages unprocessable messages and publishes them on a separate topic.
// The main middleware chain then continues on, business as usual.
func PoisonQueue(pub message.Publisher, topic string) (message.HandlerMiddleware, error) {
	if topic == "" {
		return nil, ErrInvalidPoisonQueueTopic
	}

	pq := poisonQueue{
		topic: topic,
		pub:   pub,
		shouldGoToPoisonQueue: func(err error) bool {
			return true
		},
	}

	return pq.Middleware, nil
}
// PoisonQueueWithFilter is just like PoisonQueue, but accepts a function that decides which errors qualify for the poison queue.
func PoisonQueueWithFilter(pub message.Publisher, topic string, shouldGoToPoisonQueue func(err error) bool) (message.HandlerMiddleware, error) {
	if topic == "" {
		return nil, ErrInvalidPoisonQueueTopic
	}

	pq := poisonQueue{
		topic: topic,
		pub:   pub,

		shouldGoToPoisonQueue: shouldGoToPoisonQueue,
	}

	return pq.Middleware, nil
}
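
A sketch of diverting unprocessable messages to a dedicated topic; pub is assumed to be an existing message.Publisher and the topic name is illustrative. PoisonQueueWithFilter works the same way, with the extra predicate deciding which errors are diverted.

poisonMiddleware, err := middleware.PoisonQueue(pub, "poison_queue") // pub: your message.Publisher
if err != nil {
	panic(err)
}
router.AddMiddleware(poisonMiddleware)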

Randomfail

// RandomFail makes the handler fail with an error based on random chance. Error probability should be in the range (0,1).
func RandomFail(errorProbability float32) message.HandlerMiddleware {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(message *message.Message) ([]*message.Message, error) {
			if shouldFail(errorProbability) {
				return nil, errors.New("random fail occurred")
			}

			return h(message)
		}
	}
}
// RandomPanic makes the handler panic based on random chance. Panic probability should be in the range (0,1).
func RandomPanic(panicProbability float32) message.HandlerMiddleware {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(message *message.Message) ([]*message.Message, error) {
			if shouldFail(panicProbability) {
				panic("random panic occurred")
			}

			return h(message)
		}
	}
}
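
A sketch of using these in a test or staging setup to exercise Retry, Recoverer, and PoisonQueue under failure; the probabilities are arbitrary.

router.AddMiddleware(
	middleware.RandomFail(0.01),   // fail roughly 1% of messages
	middleware.RandomPanic(0.005), // panic on roughly 0.5% of messages
)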

Recoverer

// RecoveredPanicError holds the recovered panic value along with the stacktrace.
type RecoveredPanicError struct {
	V          interface{}
	Stacktrace string
}
// Recoverer recovers from any panic in the handler and converts it into a
// RecoveredPanicError carrying the stacktrace, which is returned as the handler's error.
func Recoverer(h message.HandlerFunc) message.HandlerFunc {
	return func(event *message.Message) (events []*message.Message, err error) {
		panicked := true

		defer func() {
			if r := recover(); r != nil || panicked {
				err = errors.WithStack(RecoveredPanicError{V: r, Stacktrace: string(debug.Stack())})
			}
		}()

		events, err = h(event)
		panicked = false
		return events, err
	}
}
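
Recoverer already matches the message.HandlerMiddleware signature, so it can be added directly; it is typically registered so that panics in handlers are converted into ordinary errors that Retry or PoisonQueue can act on.

router.AddMiddleware(middleware.Recoverer)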

Retry

// Retry provides a middleware that retries the handler if errors are returned.
// The retry behaviour is configurable, with exponential backoff and maximum elapsed time.
type Retry struct {
	// MaxRetries is the maximum number of times a retry will be attempted.
	MaxRetries int

	// InitialInterval is the first interval between retries. Subsequent intervals will be scaled by Multiplier.
	InitialInterval time.Duration
	// MaxInterval sets the limit for the exponential backoff of retries. The interval will not be increased beyond MaxInterval.
	MaxInterval time.Duration
	// Multiplier is the factor by which the waiting interval will be multiplied between retries.
	Multiplier float64
	// MaxElapsedTime sets the time limit of how long retries will be attempted. Disabled if 0.
	MaxElapsedTime time.Duration
	// RandomizationFactor randomizes the spread of the backoff times within the interval of:
	// [currentInterval * (1 - randomization_factor), currentInterval * (1 + randomization_factor)].
	RandomizationFactor float64

	// OnRetryHook is an optional function that will be executed on each retry attempt.
	// The number of the current retry is passed as retryNum, and the backoff delay applied before it as delay.
	OnRetryHook func(retryNum int, delay time.Duration)

	Logger watermill.LoggerAdapter
}
// Middleware returns the Retry middleware.
func (r Retry) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		producedMessages, err := h(msg)
		if err == nil {
			return producedMessages, nil
		}

		expBackoff := backoff.NewExponentialBackOff()
		expBackoff.InitialInterval = r.InitialInterval
		expBackoff.MaxInterval = r.MaxInterval
		expBackoff.Multiplier = r.Multiplier
		expBackoff.MaxElapsedTime = r.MaxElapsedTime
		expBackoff.RandomizationFactor = r.RandomizationFactor

		ctx := msg.Context()
		if r.MaxElapsedTime > 0 {
			var cancel func()
			ctx, cancel = context.WithTimeout(ctx, r.MaxElapsedTime)
			defer cancel()
		}

		retryNum := 1
		expBackoff.Reset()
	retryLoop:
		for {
			waitTime := expBackoff.NextBackOff()
			select {
			case <-ctx.Done():
				return producedMessages, err
			case <-time.After(waitTime):
				// go on
			}

			producedMessages, err = h(msg)
			if err == nil {
				return producedMessages, nil
			}

			if r.Logger != nil {
				r.Logger.Error("Error occurred, retrying", err, watermill.LogFields{
					"retry_no":     retryNum,
					"max_retries":  r.MaxRetries,
					"wait_time":    waitTime,
					"elapsed_time": expBackoff.GetElapsedTime(),
				})
			}
			if r.OnRetryHook != nil {
				r.OnRetryHook(retryNum, waitTime)
			}

			retryNum++
			if retryNum > r.MaxRetries {
				break retryLoop
			}
		}

		return nil, err
	}
}
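
A sketch of a common Retry configuration; the values are illustrative and logger is assumed to be a watermill.LoggerAdapter created elsewhere.

retry := middleware.Retry{
	MaxRetries:      3,
	InitialInterval: 100 * time.Millisecond,
	MaxInterval:     time.Second,
	Multiplier:      2,
	MaxElapsedTime:  30 * time.Second,
	Logger:          logger, // watermill.LoggerAdapter, e.g. watermill.NewStdLogger(false, false)
}
router.AddMiddleware(retry.Middleware)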

Throttle

// Throttle provides a middleware that limits the number of messages processed per unit of time.
// This may be done, for example, to prevent excessive load caused by running a handler on a long queue of unprocessed messages.
type Throttle struct {
	ticker *time.Ticker
}
// NewThrottle creates a new Throttle middleware.
// For example, NewThrottle(10, time.Second) allows up to 10 messages per second.
func NewThrottle(count int64, duration time.Duration) *Throttle {
	return &Throttle{
		ticker: time.NewTicker(duration / time.Duration(count)),
	}
}
// Middleware returns the Throttle middleware.
func (t Throttle) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(message *message.Message) ([]*message.Message, error) {
		// throttle is shared by multiple handlers, which will wait for their "tick"
		<-t.ticker.C

		return h(message)
	}
}
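
A sketch limiting throughput to roughly 100 messages per second; note from the code above that the ticker is shared, so the limit applies across all handlers using the same Throttle instance.

throttle := middleware.NewThrottle(100, time.Second) // ~100 messages per second
router.AddMiddleware(throttle.Middleware)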

Timeout

// Timeout makes the handler cancel the incoming message's context after a specified time.
// Any timeout-sensitive functionality of the handler should listen on msg.Context().Done() to know when to fail.
func Timeout(timeout time.Duration) func(message.HandlerFunc) message.HandlerFunc {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(msg *message.Message) ([]*message.Message, error) {
			ctx, cancel := context.WithTimeout(msg.Context(), timeout)
			defer func() {
				cancel()
			}()

			msg.SetContext(ctx)
			return h(msg)
		}
	}
}
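
A sketch with a 30-second timeout; the handler body is illustrative and only shows how to respect the deadline via the message context.

router.AddMiddleware(middleware.Timeout(30 * time.Second))

// Inside a handler, check msg.Context() so that slow work stops once the deadline passes:
handler := func(msg *message.Message) ([]*message.Message, error) {
	if err := msg.Context().Err(); err != nil {
		return nil, err // deadline already exceeded
	}
	// ... pass msg.Context() to any I/O the handler performs ...
	return nil, nil
}
_ = handler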
