Middlewares

Add functionality to handlers.

Introduction

Middlewares wrap handlers with functionality that is important, but not relevant for the primary handler’s logic. Examples include retrying the handler after an error was returned, or recovering from panic in the handler and capturing the stacktrace.

Middlewares wrap the handler function like this:

Full source: github.com/ThreeDotsLabs/watermill/message/router.go

// ...
// HandlerMiddleware is a decorator for HandlerFunc: it receives a handler and
// returns a new handler that wraps it.
//
// The wrapper can execute something before the wrapped handler (for example: modify the consumed message)
// or after it (modify produced messages, ack/nack the consumed message, handle errors, logging, etc.).
//
// It can be attached to the router by using the `AddMiddleware` method.
//
// Example:
//
//	func ExampleMiddleware(h message.HandlerFunc) message.HandlerFunc {
//		return func(msg *message.Message) ([]*message.Message, error) {
//			fmt.Println("executed before handler")
//			producedMessages, err := h(msg)
//			fmt.Println("executed after handler")
//
//			return producedMessages, err
//		}
//	}
type HandlerMiddleware func(h HandlerFunc) HandlerFunc
// ...

Usage

Middlewares can be executed for all handlers in a router or only for a specific one. When a middleware is added directly to the router, it is executed for every handler registered on that router. If a middleware should be executed only for a specific handler, it needs to be added to that handler instead.

Example usage is shown below:

Full source: github.com/ThreeDotsLabs/watermill/_examples/basic/3-router/main.go

// ...
	router, err := message.NewRouter(message.RouterConfig{}, logger)
	if err != nil {
		panic(err)
	}

	// SignalsHandler will gracefully shut down the Router when SIGTERM is received.
	// You can also close the router by just calling `router.Close()`.
	router.AddPlugin(plugin.SignalsHandler)

	// Router-level middlewares are executed for every message sent to the router.
	router.AddMiddleware(
		// CorrelationID will copy the correlation ID from the incoming message's metadata to the produced messages.
		middleware.CorrelationID,

		// The handler function is retried if it returns an error.
		// After MaxRetries, the message is Nacked and it's up to the PubSub to resend it.
		middleware.Retry{
			MaxRetries:      3,
			InitialInterval: time.Millisecond * 100,
			Logger:          logger,
		}.Middleware,

		// Recoverer handles panics from handlers.
		// In this case, it passes them as errors to the Retry middleware.
		middleware.Recoverer,
	)

	// For simplicity, we are using the gochannel Pub/Sub here.
	// You can replace it with any Pub/Sub implementation, it will work the same.
	pubSub := gochannel.NewGoChannel(gochannel.Config{}, logger)

	// Produce some incoming messages in the background.
	go publishMessages(pubSub)

	// AddHandler returns a handler which can be used to add handler-level middleware
	// or to stop the handler.
	handler := router.AddHandler(
		"struct_handler",          // handler name, must be unique
		"incoming_messages_topic", // topic from which we will read events
		pubSub,
		"outgoing_messages_topic", // topic to which we will publish events
		pubSub,
		structHandler{}.Handler,
	)

	// Handler-level middleware is only executed for a specific handler.
	// Such middleware can be added the same way as the router-level ones.
	handler.AddMiddleware(func(h message.HandlerFunc) message.HandlerFunc {
		return func(message *message.Message) ([]*message.Message, error) {
			log.Println("executing handler specific middleware for ", message.UUID)

			return h(message)
		}
	})

	// Just for debugging, we print all messages received on `incoming_messages_topic`.
	router.AddNoPublisherHandler(
		"print_incoming_messages",
		"incoming_messages_topic",
		pubSub,
		printMessages,
	)

	// Just for debugging, we print all events sent to `outgoing_messages_topic`.
	router.AddNoPublisherHandler(
		"print_outgoing_messages",
		"outgoing_messages_topic",
		pubSub,
		printMessages,
	)

	// Now that all handlers are registered, we're running the Router.
	// Run is blocking while the router is running.
// ...

Available middlewares

Below are the middlewares provided by Watermill and ready to use. You can also easily implement your own. For example, if you’d like to store every received message in some kind of log, a custom middleware is the best way to do it.

Circuit Breaker

// CircuitBreaker is a middleware that wraps the handler in a circuit breaker.
// Based on the configuration, the circuit breaker will fail fast if the handler keeps returning errors.
// This is useful for preventing cascading failures.
type CircuitBreaker struct {
	// cb is the gobreaker circuit breaker through which the wrapped handler is executed.
	cb *gobreaker.CircuitBreaker
}
// NewCircuitBreaker builds a CircuitBreaker middleware backed by a gobreaker
// circuit breaker created from the given settings.
// See the gobreaker documentation for the meaning of the individual settings.
func NewCircuitBreaker(settings gobreaker.Settings) CircuitBreaker {
	breaker := gobreaker.NewCircuitBreaker(settings)

	return CircuitBreaker{cb: breaker}
}
// Middleware wraps h so that every invocation is executed through the
// circuit breaker.
//
// The error reported by the breaker's Execute call is always returned as is.
// When Execute yields no value, nil messages are returned with that error;
// otherwise the value is returned as the produced messages.
func (c CircuitBreaker) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		raw, err := c.cb.Execute(func() (interface{}, error) {
			return h(msg)
		})
		if raw == nil {
			return nil, err
		}

		// The only value ever handed to Execute is the handler's message slice.
		return raw.([]*message.Message), err
	}
}

Correlation

// SetCorrelationID stores id as the correlation ID of msg, unless msg already
// carries a non-empty correlation ID, in which case msg is left untouched.
//
// SetCorrelationID should be called when the message enters the system.
// When the message is produced in a request (for example HTTP),
// the message correlation ID should be the same as the request's correlation ID.
func SetCorrelationID(id string, msg *message.Message) {
	if existing := MessageCorrelationID(msg); existing == "" {
		msg.Metadata.Set(CorrelationIDMetadataKey, id)
	}
}
// MessageCorrelationID reads the correlation ID of the message from its
// metadata, using the CorrelationIDMetadataKey key.
func MessageCorrelationID(message *message.Message) string {
	correlationID := message.Metadata.Get(CorrelationIDMetadataKey)

	return correlationID
}
// CorrelationID is a middleware that propagates the correlation ID of the
// received message to every message produced by the handler.
// Produced messages that already have a correlation ID keep it.
//
// For CorrelationID to work correctly, SetCorrelationID must be called on the
// first message entering the system.
func CorrelationID(h message.HandlerFunc) message.HandlerFunc {
	return func(message *message.Message) ([]*message.Message, error) {
		produced, err := h(message)

		// The ID is read after the handler has run, and it is applied
		// regardless of whether the handler returned an error.
		id := MessageCorrelationID(message)
		for i := range produced {
			SetCorrelationID(id, produced[i])
		}

		return produced, err
	}
}

Deduplicator

// Deduplicator drops similar messages if they are present
// in a [ExpiringKeyRepository]. The similarity is determined
// by a [MessageHasher]. A timeout is applied to repository
// operations using [context.WithTimeout].
//
// Call [Deduplicator.Middleware] for a new middleware
// or [Deduplicator.PublisherDecorator] for a [message.PublisherDecorator].
//
// KeyFactory defaults to [NewMessageHasherAdler32] with read
// limit set to [math.MaxInt64] for fast tagging.
// Use [NewMessageHasherSHA256] for minimal collisions.
//
// Repository defaults to [NewMapExpiringKeyRepository] with one
// minute retention window. This default setting is performant
// but **does not support distributed operations**. If you
// implement a [ExpiringKeyRepository] backed by Redis,
// please submit a pull request.
//
// Timeout defaults to one minute. If lower than
// five milliseconds, it is set to five milliseconds.
//
// [ExpiringKeyRepository] must expire values
// in a certain time window. If there is no expiration, only one
// unique message will be ever delivered as long as the repository
// keeps its state.
type Deduplicator struct {
	// KeyFactory computes the key under which a message is looked up in Repository.
	KeyFactory MessageHasher
	// Repository answers whether a given key has already been seen.
	Repository ExpiringKeyRepository
	// Timeout bounds each Repository lookup.
	Timeout    time.Duration
}
// IsDuplicate reports whether the message is considered a duplicate.
//
// The message key is computed with KeyFactory and then looked up in
// Repository. The lookup runs with a context derived from the message's
// context and limited by Timeout. An error from KeyFactory is returned
// together with false; otherwise the Repository's answer is returned as is.
func (d *Deduplicator) IsDuplicate(m *message.Message) (bool, error) {
	tag, hashErr := d.KeyFactory(m)
	if hashErr != nil {
		return false, hashErr
	}

	lookupCtx, stop := context.WithTimeout(m.Context(), d.Timeout)
	defer stop()

	return d.Repository.IsDuplicate(lookupCtx, tag)
}
// Middleware returns a [message.HandlerMiddleware] that skips the wrapped
// handler for messages recognized as duplicates.
//
// Defaults are applied to the Deduplicator once, when the middleware is
// created. For each message: a failed duplicate check is returned as the
// handler error, a duplicate yields no messages and no error, and any
// other message is passed on to h.
func (d *Deduplicator) Middleware(h message.HandlerFunc) message.HandlerFunc {
	dedup := applyDefaultsToDeduplicator(d)

	return func(msg *message.Message) ([]*message.Message, error) {
		duplicate, err := dedup.IsDuplicate(msg)
		switch {
		case err != nil:
			return nil, err
		case duplicate:
			// Drop the message without running the handler.
			return nil, nil
		default:
			return h(msg)
		}
	}
}
// NewMapExpiringKeyRepository returns a memory store
// backed by a regular hash map protected by
// a [sync.Mutex]. The state **cannot be shared or synchronized
// between instances** by design for performance.
//
// If you need to drop duplicate messages by orchestration,
// implement the [ExpiringKeyRepository] interface backed by Redis
// or similar.
//
// Window specifies the minimum duration for which duplicate
// tags are remembered. The real duration can extend up to 50%
// longer because the clean-up cycle ticks every window/2.
// A window shorter than one millisecond is rejected with an error.
//
// The clean-up loop is started in its own goroutine with
// [context.Background], so no cancellation is wired in by this constructor.
func NewMapExpiringKeyRepository(window time.Duration) (ExpiringKeyRepository, error) {
	if window < time.Millisecond {
		return nil, errors.New("deduplication window of less than a millisecond is impractical")
	}

	repo := &mapExpiringKeyRepository{
		window: window,
		mu:     &sync.Mutex{},
		tags:   make(map[string]time.Time),
	}

	cleanUpTicker := time.NewTicker(window / 2)
	go repo.cleanOutLoop(context.Background(), cleanUpTicker)

	return repo, nil
}
// Len returns the number of tags currently held in the map,
// i.e. those that have not been cleaned out yet.
// The count is taken while holding the repository mutex.
func (kr *mapExpiringKeyRepository) Len() (count int) {
	kr.mu.Lock()
	defer kr.mu.Unlock()

	return len(kr.tags)
}
// NewMessageHasherAdler32 generates message hashes using a fast
// Adler-32 checksum of the [message.Message] payload. Read
// limit specifies how many bytes of the payload are
// used for calculating the hash.
//
// A lower limit improves performance but results in more false
// positives. A read limit below [MessageHasherReadLimitMinimum]
// is raised to that minimum.
//
// The returned hash is the raw checksum bytes converted to a string.
func NewMessageHasherAdler32(readLimit int64) MessageHasher {
	limit := readLimit
	if limit < MessageHasherReadLimitMinimum {
		limit = MessageHasherReadLimitMinimum
	}

	return func(m *message.Message) (string, error) {
		checksum := adler32.New()
		payload := bytes.NewReader(m.Payload)

		// io.EOF only means the payload is shorter than the limit.
		if _, err := io.CopyN(checksum, payload, limit); err != nil && err != io.EOF {
			return "", err
		}

		return string(checksum.Sum(nil)), nil
	}
}
// NewMessageHasherSHA256 generates message hashes using a slower
// but more resilient SHA-256 hashing of the [message.Message] payload.
// Read limit specifies how many bytes of the payload are
// used for calculating the hash.
//
// A lower limit improves performance but results in more false
// positives. A read limit below [MessageHasherReadLimitMinimum]
// is raised to that minimum.
//
// The returned hash is the raw digest bytes converted to a string.
func NewMessageHasherSHA256(readLimit int64) MessageHasher {
	limit := readLimit
	if limit < MessageHasherReadLimitMinimum {
		limit = MessageHasherReadLimitMinimum
	}

	return func(m *message.Message) (string, error) {
		digest := sha256.New()
		payload := bytes.NewReader(m.Payload)

		// io.EOF only means the payload is shorter than the limit.
		if _, err := io.CopyN(digest, payload, limit); err != nil && err != io.EOF {
			return "", err
		}

		return string(digest.Sum(nil)), nil
	}
}
// NewMessageHasherFromMetadataField returns a [MessageHasher] that reads
// the hash value from the given metadata field instead of calculating
// a new one. Useful if a [MessageHasher] was applied in a previous
// [message.HandlerFunc].
//
// The hasher returns an error when the field is absent from the
// message metadata.
func NewMessageHasherFromMetadataField(field string) MessageHasher {
	return func(m *message.Message) (string, error) {
		stored, found := m.Metadata[field]
		if !found {
			return "", fmt.Errorf("cannot recover hash value from metadata of message #%s: field %q is absent", m.UUID, field)
		}

		return stored, nil
	}
}
// PublisherDecorator returns a decorator that
// acknowledges and drops every [message.Message] that
// was recognized by a [Deduplicator].
//
// The returned decorator provides the same functionality
// to a [message.Publisher] as [Deduplicator.Middleware]
// to a [message.Router].
//
// Decorating a nil publisher fails with an error. Defaults are applied
// to the Deduplicator each time the decorator wraps a publisher.
func (d *Deduplicator) PublisherDecorator() message.PublisherDecorator {
	return func(pub message.Publisher) (message.Publisher, error) {
		if pub == nil {
			return nil, errors.New("cannot decorate a <nil> publisher")
		}

		decorated := &deduplicatingPublisherDecorator{
			Publisher:    pub,
			deduplicator: applyDefaultsToDeduplicator(d),
		}

		return decorated, nil
	}
}

Duplicator

// Duplicator runs the handler twice for every message, to help verify
// that the endpoint is idempotent.
//
// If either run fails, its error is returned with no messages and the
// remaining run (if any) is skipped. Otherwise the messages produced by
// the first run are returned, followed by those of the second run.
func Duplicator(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		first, err := h(msg)
		if err != nil {
			return nil, err
		}

		second, err := h(msg)
		if err != nil {
			return nil, err
		}

		combined := append(first, second...)
		return combined, nil
	}
}

Ignore Errors

// IgnoreErrors provides a middleware that makes the handler ignore some explicitly whitelisted errors.
type IgnoreErrors struct {
	// ignoredErrors is the set of error messages (as returned by Error())
	// that the middleware suppresses.
	ignoredErrors map[string]struct{}
}

// NewIgnoreErrors creates a new IgnoreErrors middleware that ignores the
// given errors. Errors are identified by their message, so two distinct
// error values with the same text are treated as the same error.
//
// Nil entries in errs are skipped. An empty or nil errs yields a middleware
// that ignores nothing.
func NewIgnoreErrors(errs []error) IgnoreErrors {
	errsMap := make(map[string]struct{}, len(errs))

	for _, err := range errs {
		if err == nil {
			// Calling Error() on a nil error would panic; there is nothing to whitelist.
			continue
		}
		errsMap[err.Error()] = struct{}{}
	}

	return IgnoreErrors{ignoredErrors: errsMap}
}
// Middleware returns the IgnoreErrors middleware.
//
// The messages produced by h are always returned. An error from h is
// suppressed (nil is returned instead) when the message of its root
// cause, as reported by errors.Cause, is one of the ignored errors.
// Any other error is returned unchanged.
func (i IgnoreErrors) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		events, err := h(msg)
		if err == nil {
			return events, nil
		}

		cause := errors.Cause(err).Error()
		if _, ignored := i.ignoredErrors[cause]; ignored {
			return events, nil
		}

		return events, err
	}
}

Instant Ack

// InstantAck acknowledges the incoming message before the handler runs,
// so the acknowledgement does not depend on whether the handler fails.
//
// It may be used to gain throughput, but at a cost:
// If you had exactly-once delivery, you may expect at-least-once instead.
// If you had ordered messages, the ordering might be broken.
func InstantAck(h message.HandlerFunc) message.HandlerFunc {
	return func(message *message.Message) ([]*message.Message, error) {
		// The result of Ack is deliberately not inspected.
		message.Ack()

		produced, err := h(message)
		return produced, err
	}
}

Poison

// PoisonQueue provides a middleware that salvages unprocessable messages and publishes them on a separate topic.
// The main middleware chain then continues on, business as usual.
//
// Every error qualifies for the poison queue. An empty topic is rejected
// with ErrInvalidPoisonQueueTopic.
func PoisonQueue(pub message.Publisher, topic string) (message.HandlerMiddleware, error) {
	if topic == "" {
		return nil, ErrInvalidPoisonQueueTopic
	}

	acceptAll := func(err error) bool { return true }
	queue := poisonQueue{
		topic:                 topic,
		pub:                   pub,
		shouldGoToPoisonQueue: acceptAll,
	}

	return queue.Middleware, nil
}
// PoisonQueueWithFilter is just like PoisonQueue, but accepts a function that decides which errors qualify for the poison queue.
//
// An empty topic is rejected with ErrInvalidPoisonQueueTopic.
// shouldGoToPoisonQueue is stored as given, without validation.
func PoisonQueueWithFilter(pub message.Publisher, topic string, shouldGoToPoisonQueue func(err error) bool) (message.HandlerMiddleware, error) {
	if topic == "" {
		return nil, ErrInvalidPoisonQueueTopic
	}

	queue := poisonQueue{
		topic:                 topic,
		pub:                   pub,
		shouldGoToPoisonQueue: shouldGoToPoisonQueue,
	}

	return queue.Middleware, nil
}

Randomfail

// RandomFail makes the handler fail with an error based on random chance.
// Error probability should be in the range (0,1).
//
// When a failure is drawn, the wrapped handler is not called and
// no messages are returned.
func RandomFail(errorProbability float32) message.HandlerMiddleware {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(message *message.Message) ([]*message.Message, error) {
			if !shouldFail(errorProbability) {
				return h(message)
			}

			return nil, errors.New("random fail occurred")
		}
	}
}
// RandomPanic makes the handler panic based on random chance.
// Panic probability should be in the range (0,1).
//
// When a panic is drawn, it happens before the wrapped handler is called.
func RandomPanic(panicProbability float32) message.HandlerMiddleware {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(message *message.Message) ([]*message.Message, error) {
			if !shouldFail(panicProbability) {
				return h(message)
			}

			panic("random panic occurred")
		}
	}
}

Recoverer

// RecoveredPanicError holds the recovered panic's value along with the stacktrace.
type RecoveredPanicError struct {
	// V is the value returned by recover().
	V          interface{}
	// Stacktrace is the output of debug.Stack() captured while recovering.
	Stacktrace string
}
// Recoverer recovers from any panic in the handler and returns it as an
// error: a RecoveredPanicError holding the recovered value and the
// stacktrace, wrapped with errors.WithStack.
//
// When the handler does not panic, its messages and error are returned
// unchanged.
func Recoverer(h message.HandlerFunc) message.HandlerFunc {
	return func(event *message.Message) (events []*message.Message, err error) {
		// completed stays false if h never returns normally. This also
		// catches a panic whose recovered value is nil.
		completed := false

		defer func() {
			recovered := recover()
			if recovered == nil && completed {
				return
			}

			err = errors.WithStack(RecoveredPanicError{V: recovered, Stacktrace: string(debug.Stack())})
		}()

		events, err = h(event)
		completed = true
		return events, err
	}
}

Retry

// Retry provides a middleware that retries the handler if errors are returned.
// The retry behaviour is configurable, with exponential backoff and maximum elapsed time.
type Retry struct {
	// MaxRetries is maximum number of times a retry will be attempted.
	MaxRetries int

	// InitialInterval is the first interval between retries. Subsequent intervals will be scaled by Multiplier.
	InitialInterval time.Duration
	// MaxInterval sets the limit for the exponential backoff of retries. The interval will not be increased beyond MaxInterval.
	MaxInterval time.Duration
	// Multiplier is the factor by which the waiting interval will be multiplied between retries.
	Multiplier float64
	// MaxElapsedTime sets the time limit of how long retries will be attempted. Disabled if 0.
	MaxElapsedTime time.Duration
	// RandomizationFactor randomizes the spread of the backoff times within the interval of:
	// [currentInterval * (1 - randomization_factor), currentInterval * (1 + randomization_factor)].
	RandomizationFactor float64

	// OnRetryHook is an optional function that will be executed after each failed retry attempt.
	// The number of the current retry is passed as retryNum,
	// and the time waited before that retry is passed as delay.
	OnRetryHook func(retryNum int, delay time.Duration)

	// Logger is optional. When set, every failed retry attempt is logged through it as an error.
	Logger watermill.LoggerAdapter
}
// Middleware returns the Retry middleware.
//
// The handler is called once up front; if it succeeds, its messages are
// returned immediately. Otherwise an exponential backoff is configured
// from the Retry fields and the handler is called again after each
// backoff interval until it succeeds or the retries are used up.
//
// Notes on the behaviour:
//   - At least one retry is always attempted, because MaxRetries is only
//     checked after a retry has failed.
//   - When MaxElapsedTime is greater than zero, waiting is additionally
//     bounded by a context with that timeout. If the context is done while
//     waiting, the messages and error of the last attempt are returned.
//   - When the retries are used up, nil messages and the last error are
//     returned.
//   - Logger and OnRetryHook, if set, are invoked after every failed retry,
//     including the last one.
func (r Retry) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(msg *message.Message) ([]*message.Message, error) {
		produced, err := h(msg)
		if err == nil {
			return produced, nil
		}

		policy := backoff.NewExponentialBackOff()
		policy.InitialInterval = r.InitialInterval
		policy.MaxInterval = r.MaxInterval
		policy.Multiplier = r.Multiplier
		policy.MaxElapsedTime = r.MaxElapsedTime
		policy.RandomizationFactor = r.RandomizationFactor

		ctx := msg.Context()
		if r.MaxElapsedTime > 0 {
			var cancel func()
			ctx, cancel = context.WithTimeout(ctx, r.MaxElapsedTime)
			defer cancel()
		}

		// Reset is called after the fields are overridden so that the
		// backoff starts from the configured values.
		policy.Reset()

		for retryNum := 1; ; retryNum++ {
			delay := policy.NextBackOff()
			select {
			case <-ctx.Done():
				return produced, err
			case <-time.After(delay):
				// waited long enough, try again
			}

			if produced, err = h(msg); err == nil {
				return produced, nil
			}

			if r.Logger != nil {
				r.Logger.Error("Error occurred, retrying", err, watermill.LogFields{
					"retry_no":     retryNum,
					"max_retries":  r.MaxRetries,
					"wait_time":    delay,
					"elapsed_time": policy.GetElapsedTime(),
				})
			}
			if r.OnRetryHook != nil {
				r.OnRetryHook(retryNum, delay)
			}

			if retryNum >= r.MaxRetries {
				return nil, err
			}
		}
	}
}

Throttle

// Throttle provides a middleware that limits the amount of messages processed per unit of time.
// This may be done e.g. to prevent excessive load caused by running a handler on a long queue of unprocessed messages.
type Throttle struct {
	// ticker paces the handlers: it ticks once per allowed message.
	ticker *time.Ticker
}

// NewThrottle creates a new Throttle middleware that lets through count
// messages per duration, evenly spaced (one every duration/count).
// Example duration and count: NewThrottle(10, time.Second) for 10 messages per second
//
// NewThrottle panics if count is zero (integer division by zero) or if
// duration/count is not positive (see time.NewTicker).
func NewThrottle(count int64, duration time.Duration) *Throttle {
	interval := duration / time.Duration(count)

	throttle := new(Throttle)
	throttle.ticker = time.NewTicker(interval)

	return throttle
}
// Middleware returns the Throttle middleware.
//
// Each message blocks until the next tick of the throttle's ticker
// before the wrapped handler is called.
func (t Throttle) Middleware(h message.HandlerFunc) message.HandlerFunc {
	return func(message *message.Message) ([]*message.Message, error) {
		// The ticker is shared by all handlers using this throttle;
		// each of them waits here for its own "tick".
		ticks := t.ticker.C
		<-ticks

		return h(message)
	}
}

Timeout

// Timeout makes the handler cancel the incoming message's context after a specified time.
// Any timeout-sensitive functionality of the handler should listen on msg.Context().Done() to know when to fail.
//
// The message's context is replaced with one derived from it that carries
// the timeout. That context is cancelled when the handler returns.
func Timeout(timeout time.Duration) func(message.HandlerFunc) message.HandlerFunc {
	return func(h message.HandlerFunc) message.HandlerFunc {
		return func(msg *message.Message) ([]*message.Message, error) {
			deadlineCtx, release := context.WithTimeout(msg.Context(), timeout)
			defer release()

			msg.SetContext(deadlineCtx)

			return h(msg)
		}
	}
}