backchannel

package
v16.3.1
Published: Aug 31, 2023 License: MIT Imports: 13 Imported by: 0

Documentation

Overview

Package backchannel implements connection multiplexing that allows for invoking gRPC methods from the server to the client.

gRPC only allows RPCs to be invoked from the client to the server. Invoking RPCs from the server to the client can be useful in some cases, such as tunneling through firewalls. While such a use case could be implemented with plain bidirectional streams, that approach has limitations that force additional work on the user. All messages in a single stream are ordered and processed sequentially, so if concurrency is desired, the user has to implement their own concurrency handling. Request routing and cancellation would also have to be implemented separately on top of the bidirectional stream.
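To give a sense of the bookkeeping a single bidirectional stream would push onto the user, here is a minimal sketch of request routing and cancellation. The frame and router types are hypothetical and not part of this package; they only illustrate that every message needs a request ID and that responses must be matched to waiting callers by hand.

```go
package main

import (
	"fmt"
	"sync"
)

// frame is a hypothetical envelope: every message on the shared stream must
// carry a request ID so that responses can be matched to callers.
type frame struct {
	id      uint64
	payload string
}

// router matches response frames arriving on one ordered stream to waiting callers.
type router struct {
	mu      sync.Mutex
	nextID  uint64
	pending map[uint64]chan string
}

func newRouter() *router { return &router{pending: map[uint64]chan string{}} }

// begin allocates a request ID and a channel on which the response will arrive.
func (r *router) begin() (uint64, <-chan string) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.nextID++
	ch := make(chan string, 1)
	r.pending[r.nextID] = ch
	return r.nextID, ch
}

// deliver routes a response frame; it reports false for unknown or cancelled requests.
func (r *router) deliver(f frame) bool {
	r.mu.Lock()
	ch, ok := r.pending[f.id]
	delete(r.pending, f.id)
	r.mu.Unlock()
	if ok {
		ch <- f.payload
	}
	return ok
}

// cancel forgets a request so that a late response is dropped.
func (r *router) cancel(id uint64) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.pending, id)
}

func main() {
	r := newRouter()
	a, respA := r.begin()
	b, respB := r.begin()
	c, _ := r.begin()
	r.cancel(c)

	// Responses arrive out of request order on the shared stream.
	for _, f := range []frame{{b, "second"}, {c, "late"}, {a, "first"}} {
		r.deliver(f)
	}
	fmt.Println(<-respA, <-respB)
}
```

With two independent gRPC sessions on one connection, gRPC itself takes care of this correlation, concurrency, and cancellation in both directions.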

To do away with these problems, this package provides a multiplexed transport for running two independent gRPC sessions on a single connection. This allows for dialing back to the client from the server to establish another gRPC session where the server and client roles are switched.

The server side uses listenmux to support clients that are unaware of the multiplexing.

Usage:

  1. Implement a ServerFactory, which is simply a function that returns a Server that can serve on the backchannel connection. Plug the ClientHandshake into the ClientConn via grpc.WithTransportCredentials when dialing. This ensures that all connections established by gRPC work with a multiplexing session and have a backchannel Server serving on them.
  2. Create a *listenmux.Mux and register a *ServerHandshaker with it.
  3. Pass the *listenmux.Mux into the gRPC Server using grpc.Creds. The Handshake method is called on each newly established connection that presents the backchannel magic bytes. It dials back to the client's backchannel server. The server makes the backchannel connections available later via the Registry's Backchannel method. The ID of the peer associated with the current RPC handler can be fetched via GetPeerID. The returned ID can be used to access the correct backchannel connection from the Registry.
Example
package main

import (
	"context"
	"fmt"
	"net"

	"github.com/sirupsen/logrus"
	"gitlab.com/gitlab-org/gitaly/v16/internal/grpc/backchannel"
	"gitlab.com/gitlab-org/gitaly/v16/internal/grpc/listenmux"
	"gitlab.com/gitlab-org/gitaly/v16/proto/go/gitalypb"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Open the server's listener.
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		fmt.Printf("failed to start listener: %v", err)
		return
	}

	// Registry is for storing the open backchannels. It should be passed into the ServerHandshaker
	// which creates the backchannel connections and adds them to the registry. The RPC handlers
	// can use the registry to access available backchannels by their peer ID.
	registry := backchannel.NewRegistry()

	logger := logrus.NewEntry(logrus.New())

	// ServerHandshaker initiates the multiplexing session on the server side. Once that is done,
	// it creates the backchannel connection and stores it into the registry. For each connection,
	// the ServerHandshaker passes down the peer ID via the context. The peer ID identifies a
	// backchannel connection.
	lm := listenmux.New(insecure.NewCredentials())
	lm.Register(backchannel.NewServerHandshaker(logger, registry, nil))

	// Create the server
	srv := grpc.NewServer(
		grpc.Creds(lm),
		grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
			fmt.Println("Gitaly received a transactional mutator")

			backchannelID, err := backchannel.GetPeerID(stream.Context())
			if err == backchannel.ErrNonMultiplexedConnection {
				// This call is from a client that is not multiplexing aware. Client is not
				// Praefect, so no need to perform voting. The client could be for example
				// GitLab calling Gitaly directly.
				fmt.Println("Gitaly responding to a non-multiplexed client")
				return stream.SendMsg(&gitalypb.CreateBranchResponse{})
			} else if err != nil {
				return fmt.Errorf("get peer id: %w", err)
			}

			backchannelConn, err := registry.Backchannel(backchannelID)
			if err != nil {
				return fmt.Errorf("get backchannel: %w", err)
			}

			fmt.Println("Gitaly sending vote to Praefect via backchannel")
			if err := backchannelConn.Invoke(
				stream.Context(), "/Praefect/VoteTransaction",
				&gitalypb.VoteTransactionRequest{}, &gitalypb.VoteTransactionResponse{},
			); err != nil {
				return fmt.Errorf("invoke backchannel: %w", err)
			}
			fmt.Println("Gitaly received vote response via backchannel")

			fmt.Println("Gitaly responding to the transactional mutator")
			return stream.SendMsg(&gitalypb.CreateBranchResponse{})
		}),
	)
	defer srv.Stop()

	// Start the server
	go func() {
		if err := srv.Serve(ln); err != nil {
			fmt.Printf("failed to serve: %v", err)
		}
	}()

	fmt.Printf("Invoke with a multiplexed client:\n\n")
	if err := invokeWithMuxedClient(logger, ln.Addr().String()); err != nil {
		fmt.Printf("failed to invoke with muxed client: %v", err)
		return
	}

	fmt.Printf("\nInvoke with a non-multiplexed client:\n\n")
	if err := invokeWithNormalClient(ln.Addr().String()); err != nil {
		fmt.Printf("failed to invoke with non-muxed client: %v", err)
		return
	}
}

func invokeWithMuxedClient(logger *logrus.Entry, address string) error {
	// clientHandshaker's ClientHandshake gets called on each established connection. The Server returned by the
	// ServerFactory is started on Praefect's end of the connection, which Gitaly can call.
	clientHandshaker := backchannel.NewClientHandshaker(logger, func() backchannel.Server {
		return grpc.NewServer(grpc.UnknownServiceHandler(func(srv interface{}, stream grpc.ServerStream) error {
			fmt.Println("Praefect received vote via backchannel")
			fmt.Println("Praefect responding via backchannel")
			return stream.SendMsg(&gitalypb.VoteTransactionResponse{})
		}))
	}, backchannel.DefaultConfiguration())

	return invokeWithOpts(address, grpc.WithTransportCredentials(clientHandshaker.ClientHandshake(insecure.NewCredentials())))
}

func invokeWithNormalClient(address string) error {
	return invokeWithOpts(address, grpc.WithTransportCredentials(insecure.NewCredentials()))
}

func invokeWithOpts(address string, opts ...grpc.DialOption) error {
	clientConn, err := grpc.Dial(address, opts...)
	if err != nil {
		return fmt.Errorf("dial server: %w", err)
	}

	if err := clientConn.Invoke(context.Background(), "/Gitaly/Mutator", &gitalypb.CreateBranchRequest{}, &gitalypb.CreateBranchResponse{}); err != nil {
		return fmt.Errorf("call server: %w", err)
	}

	if err := clientConn.Close(); err != nil {
		return fmt.Errorf("close clientConn: %w", err)
	}

	return nil
}
Output:

Invoke with a multiplexed client:

Gitaly received a transactional mutator
Gitaly sending vote to Praefect via backchannel
Praefect received vote via backchannel
Praefect responding via backchannel
Gitaly received vote response via backchannel
Gitaly responding to the transactional mutator

Invoke with a non-multiplexed client:

Gitaly received a transactional mutator
Gitaly responding to a non-multiplexed client

Constants

This section is empty.

Variables

var ErrNonMultiplexedConnection = errors.New("non-multiplexed connection")

ErrNonMultiplexedConnection is returned when attempting to get the peer id of a non-multiplexed connection.

Functions

func GetYamuxSession

func GetYamuxSession(ctx context.Context) (*yamux.Session, error)

GetYamuxSession gets the yamux session of the current peer connection.

func WithID

func WithID(authInfo credentials.AuthInfo, id ID) credentials.AuthInfo

WithID stores the ID in the provided AuthInfo so it can be later accessed by the RPC handler. This is exported to facilitate testing.

Types

type ClientHandshaker

type ClientHandshaker struct {
	// contains filtered or unexported fields
}

ClientHandshaker implements the client side handshake of the multiplexed connection.

func NewClientHandshaker

func NewClientHandshaker(logger logrus.FieldLogger, serverFactory ServerFactory, cfg Configuration) ClientHandshaker

NewClientHandshaker returns a new client side implementation of the backchannel. The provided logger is used to log multiplexing errors.

func (ClientHandshaker) ClientHandshake

ClientHandshake returns TransportCredentials that perform the client side multiplexing handshake and start the backchannel Server on the established connections. The transport credentials are used to initialize the connection prior to the multiplexing.

type Configuration

type Configuration struct {
	// MaximumStreamWindowSizeBytes sets the maximum window size in bytes used for yamux streams.
	// Higher value can increase throughput at the cost of more memory usage.
	MaximumStreamWindowSizeBytes uint32
	// AcceptBacklog sets the maximum number of stream openings in-flight before further openings
	// block.
	AcceptBacklog int
	// StreamCloseTimeout is the maximum time that a stream will be allowed to
	// be in a half-closed state when `Close` is called before forcibly
	// closing the connection.
	StreamCloseTimeout time.Duration
}

Configuration contains the configuration for the backchannel's Yamux session.

func DefaultConfiguration

func DefaultConfiguration() Configuration

DefaultConfiguration returns the default configuration.

type ID

type ID uint64

ID is a monotonically increasing number that uniquely identifies a peer connection.

func GetPeerID

func GetPeerID(ctx context.Context) (ID, error)

GetPeerID gets the ID of the current peer connection.

type Registry

type Registry struct {
	// contains filtered or unexported fields
}

Registry is a thread safe registry for backchannels. It enables accessing the backchannels via a unique ID.

func NewRegistry

func NewRegistry() *Registry

NewRegistry returns a new Registry.

func (*Registry) Backchannel

func (r *Registry) Backchannel(id ID) (*grpc.ClientConn, error)

Backchannel returns a backchannel for the ID. Returns an error if no backchannel is registered for the ID.

func (*Registry) RegisterBackchannel

func (r *Registry) RegisterBackchannel(conn *grpc.ClientConn) ID

RegisterBackchannel registers a new backchannel and returns its unique ID.

func (*Registry) RemoveBackchannel

func (r *Registry) RemoveBackchannel(id ID)

RemoveBackchannel removes a backchannel from the registry.
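The three Registry methods can be understood as a mutex-guarded map with a counter for IDs. The following is a minimal sketch under that assumption; it is not the package's implementation, it stores strings instead of *grpc.ClientConn values so that it runs without gRPC, and all names are hypothetical.

```go
package main

import (
	"fmt"
	"sync"
)

// ID mirrors the package's ID type.
type ID uint64

// registry is a simplified stand-in for Registry.
type registry struct {
	mu     sync.RWMutex
	nextID ID
	conns  map[ID]string
}

func newRegistry() *registry { return &registry{conns: map[ID]string{}} }

// register stores conn under a fresh, monotonically increasing ID.
func (r *registry) register(conn string) ID {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.nextID++
	r.conns[r.nextID] = conn
	return r.nextID
}

// backchannel looks up the connection registered for id.
func (r *registry) backchannel(id ID) (string, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	conn, ok := r.conns[id]
	if !ok {
		return "", fmt.Errorf("no backchannel for peer %d", id)
	}
	return conn, nil
}

// remove deletes the entry; the ID is never handed out again.
func (r *registry) remove(id ID) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.conns, id)
}

func main() {
	reg := newRegistry()
	first := reg.register("conn-a")
	second := reg.register("conn-b")
	reg.remove(first)

	_, err := reg.backchannel(first)
	conn, _ := reg.backchannel(second)
	fmt.Println(first, second, conn, err)
}
```

Because IDs only ever increase in this sketch, a handler that holds a stale ID gets an error after removal rather than another peer's connection.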

type Server

type Server interface {
	// Serve starts serving on the listener.
	Serve(net.Listener) error
	// Stop stops the server and closes all connections.
	Stop()
}

Server is the interface of a backchannel server.

type ServerFactory

type ServerFactory func() Server

ServerFactory returns the server that should serve on the backchannel. Each invocation should return a new server as the servers get stopped when a backchannel closes.
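The reason each invocation should return a new server can be shown with a stub. The sketch below redeclares Server and ServerFactory locally so that it compiles on its own, and stubServer is a hypothetical implementation that, like a stopped grpc.Server, refuses to serve again once Stop has been called.

```go
package main

import (
	"errors"
	"fmt"
	"net"
	"sync"
)

// Server mirrors the package's Server interface.
type Server interface {
	Serve(net.Listener) error
	Stop()
}

// ServerFactory mirrors the package's ServerFactory type.
type ServerFactory func() Server

var errStopped = errors.New("server already stopped")

// stubServer is a hypothetical Server that blocks in Serve until it is stopped.
type stubServer struct {
	once sync.Once
	done chan struct{}
}

func newStubServer() *stubServer { return &stubServer{done: make(chan struct{})} }

// Serve ignores the listener in this sketch and only models the lifecycle.
func (s *stubServer) Serve(net.Listener) error {
	select {
	case <-s.done:
		return errStopped // a stopped server cannot be reused
	default:
	}
	<-s.done
	return nil
}

// Stop is safe to call more than once.
func (s *stubServer) Stop() { s.once.Do(func() { close(s.done) }) }

func main() {
	// The factory builds a fresh server per backchannel instead of sharing one.
	factory := ServerFactory(func() Server { return newStubServer() })

	first := factory()
	first.Stop() // what happens when the first backchannel closes
	fmt.Println(first.Serve(nil))

	second := factory()
	fmt.Println(first != second)
}
```

A factory that captured and returned one shared server would hand a stopped server to every connection after the first backchannel closed.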

type ServerHandshaker

type ServerHandshaker struct {
	// contains filtered or unexported fields
}

ServerHandshaker implements the server side handshake of the multiplexed connection.

func NewServerHandshaker

func NewServerHandshaker(logger logrus.FieldLogger, reg *Registry, dialOpts []grpc.DialOption) *ServerHandshaker

NewServerHandshaker returns a new server side implementation of the backchannel. The provided TransportCredentials are handshaked prior to initializing the multiplexing session. The Registry is used to store the backchannel connections. DialOptions can be used to set custom dial options for the backchannel connections. They must not contain a dialer or transport credentials, as those are set by the handshaker.

func (*ServerHandshaker) Handshake

func (s *ServerHandshaker) Handshake(conn net.Conn, authInfo credentials.AuthInfo) (net.Conn, credentials.AuthInfo, error)

Handshake establishes a gRPC ClientConn back to the backchannel client on the other side and stores its ID in the AuthInfo where it can be later accessed by the RPC handlers. gRPC sets an IO timeout on the connection before calling ServerHandshake, so we don't have to handle timeouts separately.

func (*ServerHandshaker) Magic

func (s *ServerHandshaker) Magic() string

Magic is used by listenmux to retrieve the magic string for backchannel connections.
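How a magic string lets one listener serve both multiplexing-aware and unaware clients can be sketched with the standard library alone. This is an illustration of the general sniffing technique, not the listenmux implementation: the magic value, the sniff function and the sniffedConn type are all hypothetical, and a real server would also set a read deadline before peeking.

```go
package main

import (
	"bufio"
	"fmt"
	"io"
	"net"
)

// magic is a hypothetical marker; the real value is whatever Magic() returns.
const magic = "backchannel"

// sniffedConn replays bytes that were buffered while looking for the magic.
type sniffedConn struct {
	net.Conn
	r *bufio.Reader
}

func (c sniffedConn) Read(p []byte) (int, error) { return c.r.Read(p) }

// sniff reports whether conn starts with the magic string. The magic is consumed
// when present; otherwise every byte stays readable through the returned conn.
func sniff(conn net.Conn) (net.Conn, bool, error) {
	r := bufio.NewReader(conn)
	head, err := r.Peek(len(magic))
	if err != nil && err != io.EOF {
		return nil, false, err
	}
	if string(head) != magic {
		return sniffedConn{Conn: conn, r: r}, false, nil
	}
	if _, err := r.Discard(len(magic)); err != nil {
		return nil, false, err
	}
	return sniffedConn{Conn: conn, r: r}, true, nil
}

// classify sends payload from a simulated client over an in-memory pipe and
// returns what the server side detects and what is left to read.
func classify(payload string) (bool, string, error) {
	client, server := net.Pipe()
	defer server.Close()
	go func() {
		defer client.Close()
		_, _ = io.WriteString(client, payload)
	}()
	conn, muxed, err := sniff(server)
	if err != nil {
		return false, "", err
	}
	rest, err := io.ReadAll(conn)
	return muxed, string(rest), err
}

func main() {
	for _, payload := range []string{magic + "mux-frames", "PRI * HTTP/2.0\r\n"} {
		muxed, rest, err := classify(payload)
		fmt.Printf("muxed=%v rest=%q err=%v\n", muxed, rest, err)
	}
}
```

In this sketch a connection that opens with the magic would be handed to a handshaker, while any other connection is passed on untouched, so a plain gRPC client still sees a normal server.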
