Layered architecture

Creating maintainable and testable services.

2023-10-02

Introduction

In this post, we'll peel back the onion of server-side architecture, layer by layer. By studying the presentation, service, and persistence layers, we'll uncover the recipe for building highly maintainable and testable services. The ingredients we savour will empower us to recognise and address poorly abstracted designs in our own services.

(My partner complained that I don't make corny puns in my blog posts any more. Here you go, honey!)

We'll also put theory into practice by looking at a realistic example of a poorly abstracted application. Step by step, we'll rearchitect it into the layered design we've learned. And yes, we'll write tests too.

The accompanying code is available in this GitHub repository. It contains the full source code for each iteration of the application we'll be refactoring, along with a handy tool for interacting with it. It also contains a sandbox application that you can use to follow along with this blog post or to experiment with your own ideas.

Understanding layered architecture

Layered architecture is a design pattern in which an application is organised into distinct, loosely coupled layers. These layers are arranged in horizontal tiers, with each layer depending on the layer below it. Each layer has a specific responsibility, and only interacts with the layers adjacent to it.

This allows us to separate concerns and enforce boundaries between different parts of the application. It's kind of like a microservice architecture, but within a single application and arranged in a stack instead of a graph.

Layered architecture is also known as n-tier architecture, because there could be any number of layers (or tiers) in an application. The exact number of layers and their responsibilities can vary depending on the application's requirements, the tech stack, and the developer's preferences.

Layers of an application onion

We can imagine that the layers of an application are like the layers of an onion.

The core of the onion is the database, which contains the precious data that the entire application is built around. Enveloping the core is the persistence layer, which is responsible for interacting with the database. This layer provides an abstraction over the database, and implements the data access and manipulation logic. Anything that wants to interact with the database must go through the persistence layer first.

Above the persistence layer is the service layer, which is the heart of the application. This layer is responsible for implementing the core domain logic of the application. This includes domain data validation, data transformation (e.g. hashing passwords), and interaction with external systems. Without it, the application wouldn't do anything meaningful.

Above that is the presentation layer, which is the outermost layer. This layer is responsible for receiving requests from the client and sending responses back. In the case of a web server, this layer handles request routing, request data validation, and response data formatting. It handles all direct client interactions, protecting the inner layers from the outside world.

Although they are not part of a distinct layer, models are an important part of a layered architecture. Models are data structures which represent how data is stored in the database and how it is presented to the client. They are a shared abstraction that enables communication between the layers. I like to think of them as the juice that permeates every layer of the onion.

The client is external to the application, so it's not a part of the onion. Instead, it communicates with the application through the outermost layer, i.e. the presentation layer. You can imagine that the client is the skin of the onion, or just someone who likes talking to onions. Whatever, I don't judge.

The service layer is the most important layer, and it's the one that we want to protect from the layers above and below. In my experience, the domain logic of most poorly abstracted applications is tangled up with networking and database logic. By carving the networking and database logic out into distinct presentation and persistence layers respectively, the service layer can shine through.
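To make the stack concrete, here's a deliberately tiny sketch in Go, the language we'll use later in this post. Everything in it is hypothetical: an in-memory map stands in for the database, and the "presentation layer" returns a string rather than writing an HTTP response. What matters is the direction of the calls: each layer only talks to the one directly beneath it.

```go
package main

import (
	"errors"
	"fmt"
)

// User is a model shared by every layer.
type User struct {
	ID    int
	Email string
}

// Persistence layer: the only code that touches storage.
// An in-memory map stands in for the database in this sketch.
type memoryRepo struct {
	users  map[string]User
	nextID int
}

func (r *memoryRepo) exists(email string) bool {
	_, ok := r.users[email]
	return ok
}

func (r *memoryRepo) create(email string) User {
	r.nextID++
	u := User{ID: r.nextID, Email: email}
	r.users[email] = u
	return u
}

// Service layer: domain rules; it calls the persistence layer and nothing else.
type userService struct {
	repo *memoryRepo
}

func (s *userService) register(email string) (User, error) {
	if s.repo.exists(email) {
		return User{}, errors.New("email already in use")
	}
	return s.repo.create(email), nil
}

// Presentation layer: turns results into a client-facing format; it calls
// the service layer and never touches storage directly.
func handleRegister(svc *userService, email string) string {
	u, err := svc.register(email)
	if err != nil {
		return "400 " + err.Error()
	}
	return fmt.Sprintf("200 {id: %d, email: %s}", u.ID, u.Email)
}
```

Registering the same address twice gives a 200-style result and then a 400-style one, and the duplicate check lives in the service layer rather than in the handler or the storage code.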

Why does layered architecture matter?

I'm not a prescriptivist. I don't believe that there is a single correct way to structure an application, and I don't believe in following rules for the sake of following rules. That being said, I've found that a layered architecture tends to emerge naturally as I refactor services to solve the problems I encounter.

Applications change over time. It's inevitable. New features launch and existing features change. Businesses grow and pivot. New users are targeted and existing users develop new needs.

It's not realistic to expect that the code we write today will remain relevant tomorrow. This is particularly true for fast-moving teams and companies, especially those that are still searching for product-market fit. As a result, we need to write code that is easy to change when the time comes. It also needs to be easy to test, so that we can feel confident in the changes we make.

Layered architecture helps us achieve these goals. Separating code into distinct, loosely coupled layers lets us modify and test each layer independently, so we can write unit and integration tests instead of relying exclusively on end-to-end tests. Organising our code behind clear abstractions also makes it easier to understand and reason about. Components become more reusable, reducing code duplication, and we gain the ability to swap out an entire layer (for example, a persistence layer for a different database) without affecting the rest of the application.

Refactoring into layers

Now that we understand the different layers in a layered architecture, let's refactor a poorly abstracted application into a layered design. We will be using Go for this tutorial because it's very explicit yet still relatively simple. However, the same general guidelines apply to other languages.

For brevity, I will leave out the package imports, health check endpoint, configuration values, database setup, and program entry point. You can see these details in the accompanying code, if you're interested.

This code is not intended to be used in production. It is meant to demonstrate architectural concepts, not to be complete or correct.

You can further abstract a persistence layer by using an object-relational mapper (ORM), but this is a matter of preference. Some people prefer ORMs because they reduce the amount of boilerplate code, but others prefer raw SQL because it's more explicit and predictable. Similarly, you can further abstract a presentation layer by generating networking boilerplate or using a web framework which handles routing and data transformation for you. In the interest of learning, we will be doing this all ourselves.

A realistic example

We will be refactoring a simple REST web service which contains a single endpoint: user registration. The endpoint responds to POST requests containing an email address and password, validates them, hashes the password, and stores the new user in the database. It then returns the new user's ID and email address to the client.

Although this endpoint is simple, its logic contains elements of all the layers we've discussed. User registration is also a common feature in many applications, which makes it a realistic example.

v0: Big ball of mud

First, let's start off with a single file application, main.go. Take your time to read through it.

main.go:

func Run() {
	http.HandleFunc("/api/users/register", RegisterUser)
	fmt.Println("Server listening on", config.ServerAddress)
	log.Fatal(http.ListenAndServe(config.ServerAddress, nil))
}

func RegisterUser(w http.ResponseWriter, r *http.Request) {
	if r.Method != "POST" {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var requestData map[string]string
	err := json.NewDecoder(r.Body).Decode(&requestData)
	if err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	email, emailOk := requestData["email"]
	password, passOk := requestData["password"]
	if !emailOk || !passOk {
		http.Error(w, "Email and password are required", http.StatusBadRequest)
		return
	}

	_, err = mail.ParseAddress(email)
	if err != nil {
		http.Error(w, "Invalid email format", http.StatusBadRequest)
		return
	}

	if len(password) < 8 {
		http.Error(w, "Password must be at least 8 characters long", http.StatusBadRequest)
		return
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		log.Printf("Failed to hash password: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	db, err := sql.Open("sqlite3", config.DBPath)
	if err != nil {
		log.Printf("Failed to open database: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	defer db.Close()

	// Use a prepared statement
	stmt, err := db.Prepare("INSERT INTO users (email, password) VALUES (?, ?)")
	if err != nil {
		log.Printf("Failed to prepare statement: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	defer stmt.Close()

	_, err = stmt.Exec(email, string(hashedPassword))
	if err != nil {
		log.Printf("Failed to execute statement: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	var userID int
	var userEmail string
	err = db.QueryRow("SELECT id, email FROM users WHERE email=?", email).
		Scan(&userID, &userEmail)
	if err != nil {
		log.Printf("Failed to query database: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(
		map[string]interface{}{
			"id":    userID,
			"email": userEmail,
		},
	)
}

Although this file is simple to write, it's not very maintainable.

The first thing you'll notice is that the RegisterUser function is relatively long, considering how basic it is. It's hard to understand all the things it does at a glance, how these relate to each other, and what the side effects are. This makes it hard to reason about the code, and it will be hard to make changes to it in the future. As you can imagine, with increasing complexity, this function will only get longer and more difficult to understand.

Another problem is that the function is not very testable. The domain logic is tightly coupled to the networking logic and the database logic. As a result, we can't test any of these components in isolation.

We can still write tests for this function, but only end-to-end tests. This would require spinning up a copy of the application and a copy of the database in a test environment. Each test would make a request to the server, then check the database to see if the request was processed correctly.

This works, but it's not an ideal solution for testing every possible scenario in an application. While end-to-end tests are still useful, they're not a replacement for unit and integration tests.

v1: Separate the presentation layer from the service layer

To start off, let's separate request routing into a new file called router.go. This will allow us to add more endpoints in the future without cluttering main.go, and without repeating overlapping endpoint prefixes. It will also make it easier to add middleware, such as logging or authentication.

I'm using chi, but you can use any router you like.

router.go:

// CreateRouter creates a new router and registers all routes.
func CreateRouter() http.Handler {
	r := chi.NewRouter()

	r.Route("/api", func(r chi.Router) {
		r.Route("/users", func(r chi.Router) {
			r.Post("/register", RegisterUser)
		})
	})

	return r
}

Implementing good routing is about more than just using a router. You want your routes to be sensible and well-organised. You also don't want to create separate routes for every HTTP method or update type. For example, you can share the same route for getting and updating a user, and determine the action based on the HTTP method. You can also have a single route for updating a user, rather than separate routes for each field that can be updated.

Next, let's use the newly created router in Run.

main.go:

func Run() {
	router := CreateRouter()
	fmt.Println("Server listening on", config.ServerAddress)
	log.Fatal(http.ListenAndServe(config.ServerAddress, router))
}

We no longer need to validate the request method because it's being handled by the router.

main.go:

// RegisterUser handles user registration requests.
func RegisterUser(w http.ResponseWriter, r *http.Request) {
	var requestData map[string]string
	// ...

Next, let's create a dedicated handler for the endpoint. This will allow us to separate the HTTP handling from the rest of the endpoint logic.

Create handler.go, move the RegisterUser function into it, and rename the function to RegisterUserHandler. Next, create service.go with a new RegisterUser function.

This time, RegisterUser will take the request data as arguments, and return the response data as a result. Keep the function body the same as before, but remove the HTTP request and response logic, and adjust the return values.

service.go:

// RegisterUser registers a new user in the database and returns the user's data.
func RegisterUser(email string, password string) (int, string, error) {
	if len(password) < 8 {
		return 0, "", errors.New("password must be at least 8 characters long")
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		log.Printf("Failed to hash password: %v", err)
		return 0, "", err
	}

	db, err := sql.Open("sqlite3", config.DBPath)
	if err != nil {
		log.Printf("Failed to open database: %v", err)
		return 0, "", err
	}
	defer db.Close()

	// Use a prepared statement
	stmt, err := db.Prepare("INSERT INTO users (email, password) VALUES (?, ?)")
	if err != nil {
		log.Printf("Failed to prepare statement: %v", err)
		return 0, "", err
	}
	defer stmt.Close()

	_, err = stmt.Exec(email, string(hashedPassword))
	if err != nil {
		log.Printf("Failed to execute statement: %v", err)
		return 0, "", err
	}

	var userID int
	var userEmail string
	err = db.QueryRow("SELECT id, email FROM users WHERE email=?", email).
		Scan(&userID, &userEmail)
	if err != nil {
		log.Printf("Failed to query database: %v", err)
		return 0, "", err
	}

	return userID, userEmail, nil
}

Now remove the non-HTTP logic from RegisterUserHandler and call the new RegisterUser function instead.

handler.go:

// RegisterUserHandler handles requests to register a new user.
func RegisterUserHandler(w http.ResponseWriter, r *http.Request) {
	var requestData map[string]string
	err := json.NewDecoder(r.Body).Decode(&requestData)
	if err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	email, emailOk := requestData["email"]
	password, passOk := requestData["password"]
	if !emailOk || !passOk {
		http.Error(w, "Email and password are required", http.StatusBadRequest)
		return
	}

	_, err = mail.ParseAddress(email)
	if err != nil {
		http.Error(w, "Invalid email format", http.StatusBadRequest)
		return
	}

	userID, userEmail, err := RegisterUser(email, password)
	if err != nil {
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(
		map[string]interface{}{
			"id":    userID,
			"email": userEmail,
		},
	)
}

You may be wondering why the email address validation is in the presentation layer, but the password length validation is in the service layer. This is because the presentation layer handles validation that is about the shape or format of the request data, while the service layer handles validation that is about domain logic. For example, the presentation layer already validates that the password is non-empty, and the service layer could validate that the email address is not already in use.

However, you may have noticed that all errors returned by the service layer, including the password validation error, are currently being returned to the client as an internal server error. This is because we haven't implemented application errors yet. Application errors are specific to the application's domain logic, i.e. the service layer, and are not related to the HTTP protocol.

Create error.go with a new ValidationError type which implements the error interface.

error.go:

type ValidationError struct {
	Field   string
	Message string
}

func (e *ValidationError) Error() string {
	return e.Message
}

We've given our error type two fields. The Field field (haha) will enable us to identify which field the error is related to, which will be useful for our tests. The Message field will be used to return a human-readable error message to the user when the Error method is called.

If you want, you can instead return an error message that includes the field name and/or the error type. I chose to keep it simple so that from the user's perspective, error messages are in the same format whether they originate from the service or presentation layer.

Next, update RegisterUser to return a ValidationError if password validation fails.

service.go:

	// ...
	if len(password) < 8 {
		return 0, "", &ValidationError{
			Field:   "password",
			Message: "Password must be at least 8 characters long",
		}
	}
	// ...

Finally, update RegisterUserHandler to handle ValidationError errors.

handler.go:

	// ...
	userID, userEmail, err := RegisterUser(email, password)
	if err != nil {
		// Handle validation errors
		var validationErr *ValidationError
		if errors.As(err, &validationErr) {
			http.Error(w, validationErr.Error(), http.StatusBadRequest)
			return
		}
		// Handle other errors
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	// ...

v2: Implement models

Instead of using maps to represent request and response data, let's create dedicated models for them in a new file called message.go.

We'll call these "message models" because they represent the messages that are sent between the client and the server. The key thing to remember is that message models should only contain the fields that are relevant to the client.

message.go:

// RegisterUserRequest represents the expected format for a user registration request.
type RegisterUserRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
}

// RegisterUserResponse represents the response format after successfully registering a user.
type RegisterUserResponse struct {
	ID    int    `json:"id"`
	Email string `json:"email"`
}

In Go, we can use json tags to specify the field names in their JSON representation. We also specify the field types with consideration of JSON encoding and decoding. (Encoding is also known as serialisation or marshalling.) These types don't necessarily have to be the same as the database types.
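A quick way to see the tags at work is to round-trip the message models through encoding/json. The sketch below repeats the two structs; decodeRequest and encodeResponse are hypothetical helpers. It also shows why an empty-string check is enough for required fields: a field missing from the JSON, or an unrecognised key, simply leaves the struct field at its zero value.

```go
package main

import "encoding/json"

type RegisterUserRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
}

type RegisterUserResponse struct {
	ID    int    `json:"id"`
	Email string `json:"email"`
}

// decodeRequest decodes a raw body; fields absent from the JSON keep their
// zero value (""), and unknown keys are ignored by default.
func decodeRequest(body []byte) (RegisterUserRequest, error) {
	var req RegisterUserRequest
	err := json.Unmarshal(body, &req)
	return req, err
}

// encodeResponse encodes using the json tag names rather than the Go field names.
func encodeResponse(resp RegisterUserResponse) (string, error) {
	b, err := json.Marshal(resp)
	return string(b), err
}
```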

Next, let's update RegisterUserHandler to use the new message models.

Remember to update the references to the model fields as well. For example, requestData["email"] should become req.Email. We can also replace the emailOk and passOk variables with simple empty string checks.

handler.go:

	// ...
	var req RegisterUserRequest
	err := json.NewDecoder(r.Body).Decode(&req)
	if err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	if req.Email == "" || req.Password == "" {
		http.Error(w, "Email and password are required", http.StatusBadRequest)
		return
	}
	// ...
	resp := RegisterUserResponse{
		ID:    userID,
		Email: userEmail,
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}

Similarly, let's create dedicated models for the database in a new file called model.go. We'll call these "data models". We need a model for the users table's schema, and ideally a model for our query parameters.

model.go:

// User represents a user in the database.
type User struct {
	ID        int       `db:"id"`
	Email     string    `db:"email"`
	Password  string    `db:"password"`
	CreatedAt time.Time `db:"created_at"`
}

// CreateUserParams represents the expected fields when creating a user.
type CreateUserParams struct {
	Email    string `db:"email"`
	Password string `db:"password"`
}

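The db tags are similar to the json tags, but they specify the corresponding column names in the database schema. Note that the standard database/sql package doesn't read these tags itself (Scan assigns columns to destinations by position), so here they mainly serve as documentation; libraries such as sqlx use them to map columns to fields automatically.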
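Struct tags are just metadata attached to fields, which code can read at runtime through the reflect package. As an illustration (columnsOf and selectAllSQL are hypothetical helpers, not something this post's code uses or needs), here's how the db tags on User could generate an explicit column list.

```go
package main

import (
	"reflect"
	"strings"
	"time"
)

type User struct {
	ID        int       `db:"id"`
	Email     string    `db:"email"`
	Password  string    `db:"password"`
	CreatedAt time.Time `db:"created_at"`
}

// columnsOf reads the db tag of each field of a struct value (not a pointer)
// via reflection. Fields without a db tag are skipped.
func columnsOf(model any) []string {
	t := reflect.TypeOf(model)
	cols := make([]string, 0, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		if tag := t.Field(i).Tag.Get("db"); tag != "" {
			cols = append(cols, tag)
		}
	}
	return cols
}

// selectAllSQL builds an explicit column list instead of relying on SELECT *.
func selectAllSQL(model any, table string) string {
	return "SELECT " + strings.Join(columnsOf(model), ", ") + " FROM " + table
}
```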

Notice something? The User model has a CreatedAt field, which we haven't seen in this application so far. This is because the data model contains all the fields that are stored in the database, even if they're not relevant to the client for a specific endpoint. This distinction will become useful as our application becomes more complex, and the differences between message models and data models grow.

Next, let's update RegisterUser to use the new data models.

service.go:

	// ...
	userToCreate := CreateUserParams{
		Email:    email,
		Password: string(hashedPassword),
	}

	_, err = stmt.Exec(userToCreate.Email, userToCreate.Password)
	if err != nil {
		log.Printf("Failed to execute statement: %v", err)
		return 0, "", err
	}

	var userFromDB User
	err = db.QueryRow("SELECT id, email FROM users WHERE email=?", email).
		Scan(&userFromDB.ID, &userFromDB.Email)
	if err != nil {
		log.Printf("Failed to query database: %v", err)
		return 0, "", err
	}

	return userFromDB.ID, userFromDB.Email, nil
}

Notice that we're not selecting all the fields from the database, even though the User model can represent all of them. That's because we're currently writing a bespoke database query for each endpoint, and this endpoint only needs the id and email fields.

But now that we have separate data and message models, the service layer doesn't need to know what subset of the data models the presentation layer requires. Instead, it can return the entire data model, allowing the presentation layer to decide what data should be presented to the client. This decouples the service layer from the needs of the presentation layer, improving maintainability and reusability as our application becomes more complex.

Update RegisterUser to query for and return the entire data model. Remember to update the return values.

service.go:

func RegisterUser(email string, password string) (*User, error) {
	// ...
	var userFromDB User
	err = db.QueryRow("SELECT id, email, password, created_at FROM users WHERE email=?", email).Scan(
		&userFromDB.ID,
		&userFromDB.Email,
		&userFromDB.Password,
		&userFromDB.CreatedAt,
	)
	if err != nil {
		log.Printf("Failed to query database: %v", err)
		return nil, err
	}

	return &userFromDB, nil
}

Next, update RegisterUserHandler to use the new data model.

handler.go:

	// ...
	user, err := RegisterUser(req.Email, req.Password)
	if err != nil {
		// ...
	}

	resp := RegisterUserResponse{
		ID:    user.ID,
		Email: user.Email,
	}
	// ...

v3: Separate the persistence layer from the service layer

Next, let's separate the database logic into a repository, which abstracts away the database implementation details. This will allow us to centralise and make changes to our database logic without affecting the rest of our application. "Repository" is the common name for this pattern, but you can call the type "Store", "Queries", or whatever you'd like.

You can use an ORM if you want, but I will be writing raw SQL to keep the code agnostic.

Create repository.go with a new Repository type which encapsulates the database connection. Implement functions for creating a new Repository and closing the connection.

repository.go:

// Repository represents an interface to the database.
type Repository struct {
	db *sql.DB
}

func NewRepository(databasePath string) (*Repository, error) {
	db, err := sql.Open("sqlite3", databasePath)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	repo := &Repository{
		db: db,
	}
	return repo, nil
}

func (r *Repository) Close() error {
	return r.db.Close()
}

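Next, implement both SQL queries as methods on the Repository type, using the *sql.DB stored in the Repository to execute them. A *sql.DB is a pool of connections that is safe for concurrent use, so we can open it once when the application starts and reuse it for every query, instead of opening and closing the database on every request as before.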

repository.go:

const createUserSQL = `
INSERT INTO users (email, password)
VALUES (?, ?);
`

// CreateUser creates a new user in the database.
func (r *Repository) CreateUser(user *CreateUserParams) error {
	// Use a prepared statement
	stmt, err := r.db.Prepare(createUserSQL)
	if err != nil {
		return fmt.Errorf("failed to prepare statement: %w", err)
	}
	defer stmt.Close()

	_, err = stmt.Exec(user.Email, user.Password)
	if err != nil {
		return fmt.Errorf("failed to execute statement: %w", err)
	}

	return nil
}

const getUserByEmailSQL = `
SELECT
	id,
	email,
	password,
	created_at
FROM users
WHERE email=?;
`

// GetUserByEmail returns a user from the database by email.
func (r *Repository) GetUserByEmail(email string) (*User, error) {
	var user User
	err := r.db.QueryRow(getUserByEmailSQL, email).Scan(
		&user.ID,
		&user.Email,
		&user.Password,
		&user.CreatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to query database: %w", err)
	}

	return &user, nil
}

In order to use the new Repository type, it needs to be accessible to the RegisterUser function. As a result, we also need to create a new Service type which encapsulates the Repository and provides RegisterUser as a method.

service.go:

type Service struct {
	repo *Repository
}

func NewService(repo *Repository) *Service {
	return &Service{
		repo: repo,
	}
}

Next, update RegisterUser to use the new CreateUser and GetUserByEmail methods on the Repository type via the Service type.

service.go:

func (s *Service) RegisterUser(email string, password string) (*User, error) {
	// ...
	userToCreate := &CreateUserParams{
		Email:    email,
		Password: string(hashedPassword),
	}

	err = s.repo.CreateUser(userToCreate)
	if err != nil {
		log.Printf("Failed to create user: %v", err)
		return nil, err
	}

	userFromDB, err := s.repo.GetUserByEmail(email)
	if err != nil {
		log.Printf("Failed to get user: %v", err)
		return nil, err
	}

	return userFromDB, nil
}

Notice that the additional layer of abstraction enables us to log higher-level errors. Instead of redundantly logging errors in every layer, we wrap the persistence layer errors with context and log them in the service layer. Although we could wrap the service layer errors too and log them in the presentation layer, it's possible that service layer functions could be called without an HTTP request, like from a background job or CLI tool.

You may have noticed by now that we don't actually need GetUserByEmail in our specific use case. We could simply make CreateUser return the newly created user instead. However, I chose to include it for demonstration purposes.

Next, we need to update RegisterUserHandler to have access to Service, since RegisterUser is now a method on Service. To do this, we need to turn it from an http.HandlerFunc into a regular function which returns an http.HandlerFunc.

handler.go:

func RegisterUserHandler(svc *Service) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// ...
		user, err := svc.RegisterUser(req.Email, req.Password)
		// ...
	}
}

We also need to update CreateRouter to pass the Service to RegisterUserHandler.

router.go:

// CreateRouter creates a new router and registers all routes.
func CreateRouter(svc *Service) http.Handler {
	r := chi.NewRouter()

	r.Route("/api", func(r chi.Router) {
		r.Route("/users", func(r chi.Router) {
			r.Post("/register", RegisterUserHandler(svc))
		})
	})

	return r
}

Finally, we need to update Run to create the Repository and Service and pass them to CreateRouter.

main.go:

func Run() {
	repo, err := NewRepository(config.DBPath)
	if err != nil {
		log.Fatalf("Failed to create repository: %v", err)
	}
	defer repo.Close()

	svc := NewService(repo)
	router := CreateRouter(svc)

	fmt.Println("Server listening on", config.ServerAddress)
	log.Fatal(http.ListenAndServe(config.ServerAddress, router))
}

v4: Test our onion layers

We have now successfully defined all the layers in our application! However, we're not done yet. We're still limited in our ability to test our new layers.

This is because we've defined our Repository and Service types as concrete types. As a result, we can't mock them out in our tests. In simple language, this means that we can't write tests which replace the real Repository and Service with fake ones that we control.

This leaves us with writing end-to-end tests, which is no better than where we started. (At least, not from a testing perspective.)

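To solve this, we need to turn our concrete types into interfaces. We can then write other concrete types that satisfy the same interfaces, such as mocks, and pass them in wherever the real ones were expected. Supplying a component's dependencies from the outside like this, rather than having the component construct them itself, is called dependency injection. We've already been doing it with NewService and RegisterUserHandler; interfaces are what let us inject fakes.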

Let's start with the Repository type. Remember to update NewRepository to return the new interface type. You will also have to make other small type changes.

repository.go:

// Repository represents an interface to the database.
type Repository interface {
	CreateUser(user *CreateUserParams) error
	GetUserByEmail(email string) (*User, error)
	Close() error
}

// SQLiteRepository is a repository for SQLite databases.
type SQLiteRepository struct {
	db *sql.DB
}

func NewRepository(databasePath string) (Repository, error) {
	// ...
	repo := &SQLiteRepository{
		db: db,
	}
	return repo, nil
}

func (r *SQLiteRepository) Close() error {
	return r.db.Close()
}

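Next, let's do the same for the Service type. Remember to also change the svc parameters of RegisterUserHandler and CreateRouter from *Service to Service, and the receiver of RegisterUser from *Service to *UserService.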

service.go:

type Service interface {
	RegisterUser(email string, password string) (*User, error)
}

type UserService struct {
	repo Repository
}

func NewService(repo Repository) Service {
	return &UserService{
		repo: repo,
	}
}

Now, SQLiteRepository and UserService are concrete types which implement the Repository and Service interfaces respectively.

You might notice that we chose a technology-centric name for the concrete Repository type, yet a domain-centric name for the concrete Service type. This stems from the different dimensions of abstraction for each type. If we wanted to introduce support for another type of database or domain, we might create PostgresRepository or OrderService respectively.
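Go interfaces are satisfied implicitly, so nothing in the code above guarantees that SQLiteRepository really implements Repository until you try to use it as one. A common idiom is a compile-time assertion such as var _ Repository = (*SQLiteRepository)(nil). The sketch below applies it to a hypothetical InMemoryRepository, which also hints at what swapping in another storage backend could look like.

```go
package main

import "errors"

type CreateUserParams struct {
	Email    string
	Password string
}

type User struct {
	ID    int
	Email string
}

type Repository interface {
	CreateUser(user *CreateUserParams) error
	GetUserByEmail(email string) (*User, error)
	Close() error
}

// InMemoryRepository is a hypothetical second implementation of Repository.
type InMemoryRepository struct {
	users map[string]*User
}

// Compile-time check: the build fails if *InMemoryRepository stops
// satisfying Repository (for example, if a method signature drifts).
var _ Repository = (*InMemoryRepository)(nil)

func NewInMemoryRepository() *InMemoryRepository {
	return &InMemoryRepository{users: map[string]*User{}}
}

func (r *InMemoryRepository) CreateUser(p *CreateUserParams) error {
	if _, ok := r.users[p.Email]; ok {
		return errors.New("user already exists")
	}
	r.users[p.Email] = &User{ID: len(r.users) + 1, Email: p.Email}
	return nil
}

func (r *InMemoryRepository) GetUserByEmail(email string) (*User, error) {
	u, ok := r.users[email]
	if !ok {
		return nil, errors.New("user not found")
	}
	return u, nil
}

func (r *InMemoryRepository) Close() error { return nil }
```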

We can finally write some tests! I recommend using a mocking package for real applications, but I will be writing manual mocks to keep the code agnostic. My mocks will be naive, so please don't mock me. (Sorry, I couldn't resist.)

To test our presentation layer, we need to write unit tests which use a mock service. This is because we want to test our networking logic, not our domain logic.

Create handler_test.go with a MockService type which implements the Service interface.

handler_test.go:

type MockService struct{}

func (m *MockService) RegisterUser(email, password string) (*User, error) {
	user := &User{
		ID:        1,
		Email:     email,
		Password:  password,
		CreatedAt: time.Now(),
	}
	return user, nil
}

Next, create a TestRegisterUserHandler function which tests the RegisterUserHandler function.

handler_test.go:

func TestRegisterUserHandler(t *testing.T) {
	mockSvc := &MockService{}
	handler := RegisterUserHandler(mockSvc)

	body := map[string]string{
		"email":    "test@example.com",
		"password": "securepass123",
	}
	bodyBytes, _ := json.Marshal(body)

	req, err := http.NewRequest("POST", "/api/users/register", bytes.NewReader(bodyBytes))
	if err != nil {
		t.Fatalf("Failed to create request: %v", err)
	}

	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, req)

	if status := recorder.Code; status != http.StatusOK {
		t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
	}

	expected := map[string]interface{}{
		"id":    float64(1), // JSON numbers are parsed as floats
		"email": "test@example.com",
	}

	var actual map[string]interface{}
	err = json.NewDecoder(recorder.Body).Decode(&actual)
	if err != nil {
		t.Fatalf("Failed to decode response: %v", err)
	}

	if actual["id"] != expected["id"] {
		t.Errorf("Handler returned unexpected ID: got %v want %v", actual["id"], expected["id"])
	}

	if actual["email"] != expected["email"] {
		t.Errorf("Handler returned unexpected email: got %v want %v", actual["email"], expected["email"])
	}
}
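The `float64(1)` in the expected map is needed because, when encoding/json decodes into an `interface{}`, it represents every JSON number as a float64. Decoding into a struct with typed fields avoids that conversion. The sketch below contrasts the two approaches; registerResponse, decodeGeneric and decodeTyped are hypothetical names, and the struct assumes the response carries only `id` and `email`.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// decodeGeneric decodes into map[string]interface{}, as the handler test does.
// encoding/json stores every JSON number held in an interface{} as a float64.
func decodeGeneric(body []byte) (map[string]interface{}, error) {
	var m map[string]interface{}
	err := json.Unmarshal(body, &m)
	return m, err
}

// registerResponse is a hypothetical typed view of the handler's response,
// assuming it contains only "id" and "email" fields.
type registerResponse struct {
	ID    int    `json:"id"`
	Email string `json:"email"`
}

// decodeTyped decodes into a struct, so "id" arrives as an int.
func decodeTyped(body []byte) (registerResponse, error) {
	var r registerResponse
	err := json.Unmarshal(body, &r)
	return r, err
}

func main() {
	body := []byte(`{"id": 1, "email": "test@example.com"}`)

	generic, _ := decodeGeneric(body)
	fmt.Printf("%T\n", generic["id"]) // float64

	typed, _ := decodeTyped(body)
	fmt.Printf("%T %d\n", typed.ID, typed.ID) // int 1
}
```

The typed version also lets the test compare a whole struct at once, at the cost of silently ignoring any extra fields the handler might return.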

Let's also test all the bad request cases: request bodies with unexpected or missing fields, and an invalid email address, which checks that we're using the email validation package correctly.

handler_test.go:

func TestRegisterUserHandler_InvalidInput(t *testing.T) {
	testCases := []struct {
		name string
		body map[string]string
	}{
		{
			name: "invalid body",
			body: map[string]string{
				"test": "test",
			},
		},
		{
			name: "missing email",
			body: map[string]string{
				"password": "securepass123",
			},
		},
		{
			name: "missing password",
			body: map[string]string{
				"email": "test@example.com",
			},
		},
		{
			name: "invalid email",
			body: map[string]string{
				"email":    "test",
				"password": "securepass123",
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mockSvc := &MockService{}
			handler := RegisterUserHandler(mockSvc)

			bodyBytes, _ := json.Marshal(tc.body)

			req, err := http.NewRequest("POST", "/api/users/register", bytes.NewReader(bodyBytes))
			if err != nil {
				t.Fatalf("Failed to create request: %v", err)
			}

			recorder := httptest.NewRecorder()
			handler.ServeHTTP(recorder, req)

			status := recorder.Code
			if status != http.StatusBadRequest {
				t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusBadRequest)
			}
		})
	}
}
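Because MockService always succeeds, these tests can't exercise what the handler does when the service fails, and they can't confirm that a bad request never reaches the service. A common variation, sketched below with hypothetical names (FuncMockService, RegisterUserFunc, Calls) and trimmed-down stand-ins for User and Service, lets each test supply the mock's behaviour as a function and counts the calls.

```go
package main

import (
	"errors"
	"fmt"
)

// Hypothetical stand-ins mirroring the article's User and Service.
type User struct {
	ID    int
	Email string
}

type Service interface {
	RegisterUser(email, password string) (*User, error)
}

// FuncMockService lets each test decide how RegisterUser behaves by
// assigning RegisterUserFunc; it also counts how often it was called.
type FuncMockService struct {
	RegisterUserFunc func(email, password string) (*User, error)
	Calls            int
}

func (m *FuncMockService) RegisterUser(email, password string) (*User, error) {
	m.Calls++
	return m.RegisterUserFunc(email, password)
}

// Compile-time check that the mock satisfies the interface.
var _ Service = (*FuncMockService)(nil)

func main() {
	// A mock that simulates a failure deeper in the stack.
	failing := &FuncMockService{
		RegisterUserFunc: func(email, password string) (*User, error) {
			return nil, errors.New("database unavailable")
		},
	}
	_, err := failing.RegisterUser("test@example.com", "securepass123")
	fmt.Println(err, failing.Calls) // database unavailable 1
}
```

In TestRegisterUserHandler_InvalidInput, checking that Calls is still 0 after the request would confirm that invalid input is rejected in the presentation layer. Which status code the handler should return when the service fails depends on the handler itself, so the sketch doesn't assume one.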

To test our service layer, we need to write unit tests which use a mock repository. This is because we want to test our domain logic, not our database logic.

Create service_test.go with a MockRepository type which implements the Repository interface.

service_test.go:

type MockRepository struct{}

func (m *MockRepository) CreateUser(user *CreateUserParams) error {
	return nil
}

func (m *MockRepository) GetUserByEmail(email string) (*User, error) {
	mockUser := User{
		ID:        1,
		Email:     email,
		Password:  "securepass123",
		CreatedAt: time.Now(),
	}
	return &mockUser, nil
}

func (m *MockRepository) Close() error {
	return nil
}

Next, create a TestRegisterUser function which tests the RegisterUser method.

service_test.go:

func TestRegisterUser(t *testing.T) {
	mockRepo := &MockRepository{}
	svc := NewService(mockRepo)

	email := "test@example.com"
	password := "securepass123"

	user, err := svc.RegisterUser(email, password)
	if err != nil {
		t.Fatalf("Failed to register user: %v", err)
	}

	if user.Email != email {
		t.Errorf("Service returned unexpected email: got %v want %v", user.Email, email)
	}
}

Let's also test our validation logic. This is where ValidationError.Field comes in handy.

service_test.go:

func TestRegisterUser_InvalidInput(t *testing.T) {
	testCases := []struct {
		name     string
		field    string
		email    string
		password string
	}{
		{
			name:     "short password",
			field:    "password",
			email:    "test@example.com",
			password: "short",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mockRepo := &MockRepository{}
			svc := NewService(mockRepo)

			_, err := svc.RegisterUser(tc.email, tc.password)
			var validationErr *ValidationError
			if errors.As(err, &validationErr) {
				if validationErr.Field != tc.field {
					t.Errorf("Service returned unexpected validation error field: got %v want %v", validationErr.Field, tc.field)
				}
			} else {
				t.Errorf("Service did not return validation error")
			}
		})
	}
}

You might have noticed that TestRegisterUser does not test that the password is hashed. We can assume that the hashing package works as intended, but we should still test that we're using it correctly.

The problem is that our mock GetUserByEmail returns a hard-coded password, rather than the password that was actually passed to the mock CreateUser. The hard-coded password is unhashed, but RegisterUser should have hashed the test password before calling CreateUser.

To solve this, we need to update MockRepository to actually save the user in CreateUser and to retrieve the previously saved user in GetUserByEmail. This mimics the behaviour of a real database. We'll do this by storing the users in a map, where the key is the user's email address.

service_test.go:

type MockRepository struct {
	Users map[string]*User
}

func (m *MockRepository) CreateUser(user *CreateUserParams) error {
	mockUser := &User{
		ID:        1,
		Email:     user.Email,
		Password:  user.Password,
		CreatedAt: time.Now(),
	}

	if m.Users == nil {
		m.Users = make(map[string]*User)
	}
	m.Users[mockUser.Email] = mockUser

	return nil
}

func (m *MockRepository) GetUserByEmail(email string) (*User, error) {
	mockUser, exists := m.Users[email]
	if !exists {
		return nil, errors.New("user not found")
	}

	return mockUser, nil
}
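This mock is closer to a real database, but it still differs from the real repository, which rejects duplicate emails, in two ways: every user gets ID 1, and creating a user with an existing email silently overwrites the first one. The sketch below is a hypothetical FakeRepository (test doubles with working logic like this are often called fakes) that assigns increasing IDs and returns an error for duplicates. User and CreateUserParams are stand-ins with the fields the article's code uses.

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// Hypothetical stand-ins mirroring the article's types.
type User struct {
	ID        int
	Email     string
	Password  string
	CreatedAt time.Time
}

type CreateUserParams struct {
	Email    string
	Password string
}

// FakeRepository extends the map-backed mock: it hands out increasing IDs
// and rejects duplicate emails, as a unique constraint would.
type FakeRepository struct {
	Users  map[string]*User
	nextID int
}

func (f *FakeRepository) CreateUser(params *CreateUserParams) error {
	if f.Users == nil {
		f.Users = make(map[string]*User)
	}
	if _, exists := f.Users[params.Email]; exists {
		return errors.New("user already exists")
	}
	f.nextID++
	f.Users[params.Email] = &User{
		ID:        f.nextID,
		Email:     params.Email,
		Password:  params.Password,
		CreatedAt: time.Now(),
	}
	return nil
}

func (f *FakeRepository) GetUserByEmail(email string) (*User, error) {
	user, exists := f.Users[email] // reading a nil map is safe
	if !exists {
		return nil, errors.New("user not found")
	}
	return user, nil
}

func (f *FakeRepository) Close() error { return nil }

func main() {
	repo := &FakeRepository{}
	_ = repo.CreateUser(&CreateUserParams{Email: "a@example.com", Password: "x"})
	_ = repo.CreateUser(&CreateUserParams{Email: "b@example.com", Password: "y"})
	err := repo.CreateUser(&CreateUserParams{Email: "a@example.com", Password: "z"})

	b, _ := repo.GetUserByEmail("b@example.com")
	fmt.Println(b.ID, err) // 2 user already exists
}
```

The extra fidelity is only worth it when service behaviour depends on it, for example if RegisterUser were changed to report an already-registered email.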

Now we can update TestRegisterUser to check that the password is being hashed.

service_test.go:

	// ...
	if user.Password == password {
		t.Errorf("Service returned unhashed password")
	}
}

Many mocking libraries allow you to set "expectations" on the mocked methods. For example, if we were using gomock, we'd be able to set an expectation that CreateUser is called exactly once, and with a different password than the original password.

mockRepo.EXPECT().CreateUser(gomock.Any()).
	Do(func(user *CreateUserParams) {
		if user.Password == password {
			t.Error("Password was not hashed before creating user")
		}
	}).Times(1)
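Without a mocking library, the same two checks (called exactly once, with a changed password) can be made after the fact by a manual mock that records its calls. The sketch below uses hypothetical names (RecordingRepository, checkHashedOnce) and shows only CreateUser; a real version would also need working GetUserByEmail and Close methods to satisfy Repository, for example by embedding the map-based MockRepository.

```go
package main

import (
	"errors"
	"fmt"
)

// Hypothetical stand-in for the article's parameter type.
type CreateUserParams struct {
	Email    string
	Password string
}

// RecordingRepository records every CreateUser call so a test can check
// afterwards how many times it was called and with what arguments.
type RecordingRepository struct {
	CreateUserCalls []CreateUserParams
}

func (r *RecordingRepository) CreateUser(params *CreateUserParams) error {
	r.CreateUserCalls = append(r.CreateUserCalls, *params)
	return nil
}

// checkHashedOnce reports a problem unless CreateUser was called exactly
// once with a password different from the plaintext one.
func checkHashedOnce(r *RecordingRepository, plaintext string) error {
	if n := len(r.CreateUserCalls); n != 1 {
		return fmt.Errorf("CreateUser called %d times, want 1", n)
	}
	if r.CreateUserCalls[0].Password == plaintext {
		return errors.New("password was not hashed before creating user")
	}
	return nil
}

func main() {
	repo := &RecordingRepository{}
	// Simulate a service that forgot to hash the password.
	_ = repo.CreateUser(&CreateUserParams{Email: "test@example.com", Password: "securepass123"})
	fmt.Println(checkHashedOnce(repo, "securepass123"))
	// password was not hashed before creating user
}
```

In TestRegisterUser, the test would pass such a repository to NewService, call RegisterUser, and then fail if checkHashedOnce returns an error.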

To test our persistence layer, we need to write integration tests which use a real database. This is because we want to test that our queries are correct and that they work against the actual database engine, so mocking the database would defeat the purpose.

Create repository_test.go with a NewTestRepository function which resets the test database at config.TestDBPath and creates a new SQLiteRepository backed by it.

repository_test.go:

func NewTestRepository() (Repository, error) {
	db.ResetDB(config.TestDBPath)
	return NewRepository(config.TestDBPath)
}

Next, create a TestUsersRepository function which tests the CreateUser and GetUserByEmail methods: users should be created and retrieved correctly, duplicate emails should be rejected, and each user should get a unique ID.

repository_test.go:

func TestUsersRepository(t *testing.T) {
	testCases := []struct {
		name  string
		email string
	}{
		{
			name:  "first user",
			email: "test1@example.com",
		},
		{
			name:  "second user",
			email: "test2@example.com",
		},
	}

	repo, err := NewTestRepository()
	if err != nil {
		t.Fatalf("Failed to create repository: %v", err)
	}

	// Defers are executed in LIFO order
	defer os.Remove(config.TestDBPath)
	defer repo.Close()

	// Keep track of unique IDs returned by the repository
	uniqueIDs := make(map[int]bool)

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			userToCreate := &CreateUserParams{
				Email:    tc.email,
				Password: "securepass123",
			}

			err := repo.CreateUser(userToCreate)
			if err != nil {
				t.Fatalf("Failed to create user: %v", err)
			}

			err = repo.CreateUser(userToCreate)
			if err == nil {
				t.Fatalf("Repository created user with duplicate email")
			}

			userFromDB, err := repo.GetUserByEmail(userToCreate.Email)
			if err != nil {
				t.Fatalf("Failed to get user: %v", err)
			}

			if userFromDB.Email != userToCreate.Email {
				t.Fatalf("Repository returned unexpected email: got %v want %v", userFromDB.Email, userToCreate.Email)
			}

			uniqueIDs[userFromDB.ID] = true
		})
	}

	if len(uniqueIDs) != len(testCases) {
		t.Fatalf("Repository returned duplicate user IDs")
	}
}

Now we can run our tests with go test -v and rejoice at them all passing!

=== RUN   TestRegisterUserHandler
--- PASS: TestRegisterUserHandler (0.00s)
=== RUN   TestRegisterUserHandler_InvalidInput
=== RUN   TestRegisterUserHandler_InvalidInput/invalid_body
=== RUN   TestRegisterUserHandler_InvalidInput/missing_email
=== RUN   TestRegisterUserHandler_InvalidInput/missing_password
=== RUN   TestRegisterUserHandler_InvalidInput/invalid_email
--- PASS: TestRegisterUserHandler_InvalidInput (0.00s)
    --- PASS: TestRegisterUserHandler_InvalidInput/invalid_body (0.00s)
    --- PASS: TestRegisterUserHandler_InvalidInput/missing_email (0.00s)
    --- PASS: TestRegisterUserHandler_InvalidInput/missing_password (0.00s)
    --- PASS: TestRegisterUserHandler_InvalidInput/invalid_email (0.00s)
=== RUN   TestUsersRepository
=== RUN   TestUsersRepository/first_user
=== RUN   TestUsersRepository/second_user
--- PASS: TestUsersRepository (0.00s)
    --- PASS: TestUsersRepository/first_user (0.00s)
    --- PASS: TestUsersRepository/second_user (0.00s)
=== RUN   TestRegisterUser
--- PASS: TestRegisterUser (0.06s)
=== RUN   TestRegisterUser_InvalidInput
=== RUN   TestRegisterUser_InvalidInput/short_password
--- PASS: TestRegisterUser_InvalidInput (0.00s)
    --- PASS: TestRegisterUser_InvalidInput/short_password (0.00s)
PASS
ok      layer/versions/v4       0.168s

Conclusion

Congratulations, you've made it to the end! In this post, we learned what a layered architecture is and how it can improve the maintainability and testability of our code. We also learned how to refactor an existing application into a layered architecture, and how to write unit and integration tests. I hope you've come away with ideas on how to improve your own applications, and with increased confidence in your ability to refactor code.