Creating maintainable and testable services.
2023-10-02
In this post, we'll peel back the onion of server-side architecture, layer by layer. By studying the presentation, service, and persistence layers, we'll learn how to build highly maintainable and testable services. The techniques we learn will enable us to recognise and address poorly abstracted designs in our own services.
We'll also put theory into practice by looking at a realistic example of a poorly abstracted application. Step by step, we'll rearchitect it into the layered design we've learned. And yes, we'll write tests too.
The accompanying code is available in this GitHub repository. It contains the full source code for each iteration of the application we'll be refactoring, along with a handy tool for interacting with it. It also contains a sandbox application that you can use to follow along with this blog post or to experiment with your own ideas.
Layered architecture is a design pattern in which an application is organised into distinct, loosely coupled layers. These layers are arranged in horizontal tiers, with each layer depending on the layer below it. Each layer has a specific responsibility, and only interacts with the layers adjacent to it.
This allows us to separate concerns and enforce boundaries between different parts of the application. It's kind of like a microservice architecture, but within a single application and arranged in a stack instead of a graph.
Layered architecture is also known as n-tier architecture, because there could be any number of layers (or tiers) in an application. The exact number of layers and their responsibilities can vary depending on the application's requirements, the tech stack, and the developer's preferences.
We can imagine that the layers of an application are like the layers of an onion.
The core of the onion is the database, which contains the precious data that the entire application is built around. Enveloping the core is the persistence layer, which is responsible for interacting with the database. This layer provides an abstraction over the database, and implements the data access and manipulation logic. Anything that wants to interact with the database must go through the persistence layer first.
Above the persistence layer is the service layer, which is the heart of the application. This layer is responsible for implementing the core domain logic of the application. This includes domain data validation, data transformation (e.g. hashing passwords), and interaction with external systems. Without it, the application wouldn't do anything meaningful.
Above that is the presentation layer, which is the outermost layer. This layer is responsible for receiving requests from the client and sending responses back. In the case of a web server, this layer handles request routing, request data validation, and response data formatting. It handles all direct client interactions, protecting the inner layers from the outside world.
Although they are not part of a distinct layer, models are an important part of a layered architecture. Models are data structures which represent how data is stored in the database and how it is presented to the client. They are a shared abstraction that enables communication between the layers. I like to think of them as the juice that permeates every layer of the onion.
The client is external to the application, so it's not a part of the onion. Instead, it communicates with the application through the outermost layer, i.e. the presentation layer. You can imagine that the client is the skin of the onion, or just someone who likes talking to onions. Whatever, I don't judge.
The service layer is the most important layer, and it's the one that we want to protect from the layers above and below. In my experience, poorly abstracted applications usually have their domain logic muddled in with networking and database logic. By carving out the networking and database logic into distinct presentation and persistence layers respectively, the service layer can shine through.
I'm not a prescriptivist. I don't believe that there is a single correct way to structure an application, and I don't believe in following rules for the sake of following rules. That being said, I've found that a layered architecture tends to emerge naturally as I refactor services to solve the problems I encounter.
Applications change over time. It's inevitable. New features launch and existing features change. Businesses grow and pivot. New users are targeted and existing users develop new needs.
It's not realistic to expect that the code we write today will remain relevant tomorrow. This is particularly true for fast-moving teams and companies, especially those that are still searching for product-market fit. As a result, we need to write code that is easy to change when the time comes. It also needs to be easy to test, so that we can feel confident in the changes we make.
Layered architecture helps us achieve these goals. Separating code into distinct, loosely coupled layers enables us to modify and test each layer independently, so we can write unit and integration tests instead of relying exclusively on end-to-end tests. Organising our code and implementing clear abstractions makes it easier to understand and reason about. Components become more reusable, reducing code duplication. We also gain the ability to swap out entire layers without affecting the rest of the application.
Now that we understand the different layers in a layered architecture, let's refactor a poorly abstracted application into a layered design. We will be using Go for this tutorial because it's very explicit yet still relatively simple. However, the same general guidelines apply to other languages.
For brevity, I will leave out the package imports, health check endpoint, configuration values, database setup, and program entry point. You can see these details in the accompanying code, if you're interested.
This code is not intended to be used in production. It is meant to demonstrate architectural concepts, not to be complete or correct.
You can further abstract a persistence layer by using an object-relational mapper (ORM), but this is a matter of preference. Some people prefer ORMs because they reduce the amount of boilerplate code, but others prefer raw SQL because it's more explicit and predictable. Similarly, you can further abstract a presentation layer by generating networking boilerplate or using a web framework which handles routing and data transformation for you. In the interest of learning, we will be doing this all ourselves.
We will be refactoring a simple REST web service which contains a single endpoint: user registration. The endpoint responds to POST requests containing an email address and password, validates them, hashes the password, and stores the new user in the database. It then returns the new user's ID and email address to the client.
Although this endpoint is simple, its logic contains elements of all the layers we've discussed. User registration is also a common feature in many applications, which makes it a realistic example.
First, let's start off with a single-file application, main.go. Take your time to read through it.
main.go:
func Run() {
http.HandleFunc("/api/users/register", RegisterUser)
fmt.Println("Server listening on", config.ServerAddress)
log.Fatal(http.ListenAndServe(config.ServerAddress, nil))
}
func RegisterUser(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var requestData map[string]string
err := json.NewDecoder(r.Body).Decode(&requestData)
if err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
email, emailOk := requestData["email"]
password, passOk := requestData["password"]
if !emailOk || !passOk {
http.Error(w, "Email and password are required", http.StatusBadRequest)
return
}
_, err = mail.ParseAddress(email)
if err != nil {
http.Error(w, "Invalid email format", http.StatusBadRequest)
return
}
if len(password) < 8 {
http.Error(w, "Password must be at least 8 characters long", http.StatusBadRequest)
return
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
log.Printf("Failed to hash password: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
db, err := sql.Open("sqlite3", config.DBPath)
if err != nil {
log.Printf("Failed to open database: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
defer db.Close()
// Use a prepared statement
stmt, err := db.Prepare("INSERT INTO users (email, password) VALUES (?, ?)")
if err != nil {
log.Printf("Failed to prepare statement: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
defer stmt.Close()
_, err = stmt.Exec(email, string(hashedPassword))
if err != nil {
log.Printf("Failed to execute statement: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
var userID int
var userEmail string
err = db.QueryRow("SELECT id, email FROM users WHERE email=?", email).
Scan(&userID, &userEmail)
if err != nil {
log.Printf("Failed to query database: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(
map[string]interface{}{
"id": userID,
"email": userEmail,
},
)
}
Although this file is simple to write, it's not very maintainable. The first thing you'll notice is that the RegisterUser function is relatively long, considering how basic it is. It's hard to understand all the things it does at a glance, how these relate to each other, and what the side effects are. This makes it hard to reason about the code, and it will be hard to make changes to it in the future. As you can imagine, with increasing complexity, this function will only get longer and more difficult to understand.
Another problem is that the function is not very testable. The domain logic is tightly coupled to the networking logic and the database logic. As a result, we can't test any of these components in isolation.
We can still write tests for this function, but only end-to-end tests. This would require spinning up a copy of the application and a copy of the database in a test environment. Each test would make a request to the server, then check the database to see if the request was processed correctly.
This works, but it's not an ideal solution for testing every possible scenario in an application. While end-to-end tests are still useful, they're not a replacement for unit and integration tests.
To start off, let's separate request routing into a new file called router.go. This will allow us to add more endpoints in the future without cluttering main.go, and without repeating overlapping endpoint prefixes. It will also make it easier to add middleware, such as logging or authentication. I'm using chi, but you can use any router you like.
router.go:
// CreateRouter creates a new router and registers all routes.
func CreateRouter() http.Handler {
r := chi.NewRouter()
r.Route("/api", func(r chi.Router) {
r.Route("/users", func(r chi.Router) {
r.Post("/register", RegisterUser)
})
})
return r
}
Implementing good routing is about more than just using a router. You want your routes to be sensible and well-organised. You also don't want to create separate routes for every HTTP method or update type. For example, you can share the same route for getting and updating a user, and determine the action based on the HTTP method. You can also have a single route for updating a user, rather than separate routes for each field that can be updated.
Next, let's use the newly created router in Run.
main.go:
func Run() {
router := CreateRouter()
fmt.Println("Server listening on", config.ServerAddress)
log.Fatal(http.ListenAndServe(config.ServerAddress, router))
}
We no longer need to validate the request method because it's being handled by the router.
main.go:
// RegisterUser handles user registration requests.
func RegisterUser(w http.ResponseWriter, r *http.Request) {
var requestData map[string]string
// ...
Next, let's create a dedicated handler for the endpoint. This will allow us to separate the HTTP handling from the rest of the endpoint logic.
Create handler.go, move the RegisterUser function into it, and rename the function to RegisterUserHandler.
Next, create service.go with a new RegisterUser function. This time, RegisterUser will take the request data as arguments, and return the response data as a result. Keep the function body the same as before, but remove the HTTP request and response logic (the request format checks, such as email parsing, stay in the handler), and adjust the return values.
service.go:
// RegisterUser registers a new user in the database and returns the user's data.
func RegisterUser(email string, password string) (int, string, error) {
if len(password) < 8 {
return 0, "", errors.New("password must be at least 8 characters long")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
log.Printf("Failed to hash password: %v", err)
return 0, "", err
}
db, err := sql.Open("sqlite3", config.DBPath)
if err != nil {
log.Printf("Failed to open database: %v", err)
return 0, "", err
}
defer db.Close()
// Use a prepared statement
stmt, err := db.Prepare("INSERT INTO users (email, password) VALUES (?, ?)")
if err != nil {
log.Printf("Failed to prepare statement: %v", err)
return 0, "", err
}
defer stmt.Close()
_, err = stmt.Exec(email, string(hashedPassword))
if err != nil {
log.Printf("Failed to execute statement: %v", err)
return 0, "", err
}
var userID int
var userEmail string
err = db.QueryRow("SELECT id, email FROM users WHERE email=?", email).
Scan(&userID, &userEmail)
if err != nil {
log.Printf("Failed to query database: %v", err)
return 0, "", err
}
return userID, userEmail, nil
}
Now remove the non-HTTP logic from RegisterUserHandler and call the new RegisterUser function instead.
handler.go:
// RegisterUserHandler handles requests to register a new user.
func RegisterUserHandler(w http.ResponseWriter, r *http.Request) {
var requestData map[string]string
err := json.NewDecoder(r.Body).Decode(&requestData)
if err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
email, emailOk := requestData["email"]
password, passOk := requestData["password"]
if !emailOk || !passOk {
http.Error(w, "Email and password are required", http.StatusBadRequest)
return
}
_, err = mail.ParseAddress(email)
if err != nil {
http.Error(w, "Invalid email format", http.StatusBadRequest)
return
}
userID, userEmail, err := RegisterUser(email, password)
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(
map[string]interface{}{
"id": userID,
"email": userEmail,
},
)
}
You may be wondering why the email address validation is in the presentation layer, but the password length validation is in the service layer. This is because the presentation layer handles validation that is about the shape or format of the request data, while the service layer handles validation that is about domain logic. For example, the presentation layer already validates that a password was provided, and the service layer could validate that the email address is not already in use.
However, you may have noticed that all errors returned by the service layer, including the password validation error, are currently being returned to the client as an internal server error. This is because we haven't implemented application errors yet. Application errors are specific to the application's domain logic, i.e. the service layer, and are not related to the HTTP protocol.
Create error.go with a new ValidationError type which implements the error interface.
error.go:
type ValidationError struct {
Field string
Message string
}
func (e *ValidationError) Error() string {
return e.Message
}
We've given our error type two fields. The Field field (haha) will enable us to identify which field the error is related to, which will be useful for our tests. The Message field will be used to return a human-readable error message to the user when the Error method is called.
If you want, you can instead return an error message that includes the field name and/or the error type. I chose to keep it simple so that from the user's perspective, error messages are in the same format whether they originate from the service or presentation layer.
Next, update RegisterUser to return a ValidationError if password validation fails.
service.go:
// ...
if len(password) < 8 {
return 0, "", &ValidationError{
Field: "password",
Message: "Password must be at least 8 characters long",
}
}
// ...
Finally, update RegisterUserHandler to handle ValidationError errors.
handler.go:
// ...
userID, userEmail, err := RegisterUser(email, password)
if err != nil {
// Handle validation errors
var validationErr *ValidationError
if errors.As(err, &validationErr) {
http.Error(w, validationErr.Error(), http.StatusBadRequest)
return
}
// Handle other errors
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// ...
Instead of using maps to represent request and response data, let's create dedicated models for them in a new file called message.go.
We'll call these "message models" because they represent the messages that are sent between the client and the server. The key thing to remember is that message models should only contain the fields that are relevant to the client.
message.go:
// RegisterUserRequest represents the expected format for a user registration request.
type RegisterUserRequest struct {
Email string `json:"email"`
Password string `json:"password"`
}
// RegisterUserResponse represents the response format after successfully registering a user.
type RegisterUserResponse struct {
ID int `json:"id"`
Email string `json:"email"`
}
In Go, we can use json tags to specify the field names in their JSON representation. We also choose the field types with JSON encoding and decoding in mind. (Encoding is also known as serialisation or marshalling.) These types don't necessarily have to be the same as the database types.
Next, let's update RegisterUserHandler to use the new message models. Remember to update the references to the model fields as well. For example, requestData["email"] should become req.Email. We can also replace the emailOk and passOk variables with simple empty string checks.
handler.go:
// ...
var req RegisterUserRequest
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, "Invalid request body", http.StatusBadRequest)
return
}
if req.Email == "" || req.Password == "" {
http.Error(w, "Email and password are required", http.StatusBadRequest)
return
}
// ...
resp := RegisterUserResponse{
ID: userID,
Email: userEmail,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
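Decoding into a struct behaves a little differently from the map version, which is why the empty string checks are enough. This sketch uses a hypothetical decodeRequest helper to show that missing keys leave fields at their zero value, that unknown keys are ignored by default, and that json.Decoder's DisallowUnknownFields method makes them an error instead. Whether to enable that stricter mode is a judgment call; this post doesn't.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

type RegisterUserRequest struct {
	Email    string `json:"email"`
	Password string `json:"password"`
}

// decodeRequest is a hypothetical helper showing both decoding modes.
func decodeRequest(body string, strict bool) (RegisterUserRequest, error) {
	var req RegisterUserRequest
	dec := json.NewDecoder(strings.NewReader(body))
	if strict {
		// Reject keys that don't correspond to any struct field.
		dec.DisallowUnknownFields()
	}
	err := dec.Decode(&req)
	return req, err
}

func main() {
	// A missing key leaves the field as "", so req.Password == "" catches it.
	req, err := decodeRequest(`{"email":"test@example.com"}`, false)
	fmt.Printf("%q %v\n", req.Password, err)

	// Unknown keys are silently ignored by default...
	_, err = decodeRequest(`{"test":"test"}`, false)
	fmt.Println(err)

	// ...but rejected once DisallowUnknownFields is set.
	_, err = decodeRequest(`{"test":"test"}`, true)
	fmt.Println(err != nil)
}
```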
Similarly, let's create dedicated models for the database in a new file called model.go. We'll call these "data models". We need a model for the users table's schema, and ideally a model for our query parameters.
model.go:
// User represents a user in the database.
type User struct {
ID int `db:"id"`
Email string `db:"email"`
Password string `db:"password"`
CreatedAt time.Time `db:"created_at"`
}
// CreateUserParams represents the expected fields when creating a user.
type CreateUserParams struct {
Email string `db:"email"`
Password string `db:"password"`
}
The db tags are similar to the json tags, but they specify the field names in the database schema. The standard database/sql package doesn't read struct tags itself; they're used by libraries such as sqlx, and here they mostly document the mapping, since we scan columns into fields by hand.
Notice something? The User model has a CreatedAt field, which we haven't seen in this application so far. This is because the data model contains all the fields that are stored in the database, even if they're not relevant to the client for a specific endpoint. This distinction will become useful as our application becomes more complex, and the differences between message models and data models grow.
Next, let's update RegisterUser to use the new data models.
service.go:
// ...
userToCreate := CreateUserParams{
Email: email,
Password: string(hashedPassword),
}
_, err = stmt.Exec(userToCreate.Email, userToCreate.Password)
if err != nil {
log.Printf("Failed to execute statement: %v", err)
return 0, "", err
}
var userFromDB User
err = db.QueryRow("SELECT id, email FROM users WHERE email=?", email).
Scan(&userFromDB.ID, &userFromDB.Email)
if err != nil {
log.Printf("Failed to query database: %v", err)
return 0, "", err
}
return userFromDB.ID, userFromDB.Email, nil
}
Notice that we're not selecting all the fields from the database, even though the User model can represent all of them. That's because we're currently writing a bespoke database query for each endpoint, and this endpoint only needs the id and email fields.
But now that we have separate data and message models, the service layer doesn't need to know what subset of the data models the presentation layer requires. Instead, it can return the entire data model, allowing the presentation layer to decide what data should be presented to the client. This decouples the service layer from the needs of the presentation layer, improving maintainability and reusability as our application becomes more complex.
Update RegisterUser to query for and return the entire data model. Remember to update the return values.
service.go:
func RegisterUser(email string, password string) (*User, error) {
// ...
var userFromDB User
err = db.QueryRow("SELECT * FROM users WHERE email=?", email).Scan(
&userFromDB.ID,
&userFromDB.Email,
&userFromDB.Password,
&userFromDB.CreatedAt,
)
if err != nil {
log.Printf("Failed to query database: %v", err)
return nil, err
}
return &userFromDB, nil
}
Next, update RegisterUserHandler to use the new data model.
handler.go:
// ...
user, err := RegisterUser(req.Email, req.Password)
if err != nil {
// ...
}
resp := RegisterUserResponse{
ID: user.ID,
Email: user.Email,
}
// ...
Next, let's separate the database logic into a repository, which abstracts away the database implementation details. This will allow us to centralise and make changes to our database logic without affecting the rest of our application. "Repository" is the common name for this pattern, but you can call the type "Store", "Queries", or whatever you'd like.
You can use an ORM if you want, but I will be writing raw SQL to keep the code agnostic.
Create repository.go with a new Repository type which encapsulates the database connection. Implement functions for creating a new Repository and closing the connection.
repository.go:
// Repository represents an interface to the database.
type Repository struct {
db *sql.DB
}
func NewRepository(databasePath string) (*Repository, error) {
db, err := sql.Open("sqlite3", databasePath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
repo := &Repository{
db: db,
}
return repo, nil
}
func (r *Repository) Close() error {
return r.db.Close()
}
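Next, implement both SQL queries as methods on the Repository type, using the connection stored in the Repository to execute them. Since the same *sql.DB is shared across queries, we don't need to keep opening and closing it. (Despite its name, a *sql.DB is a pool of connections that is safe for concurrent use, so sharing one across requests is the intended usage.)
repository.go: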
const createUserSQL = `
INSERT INTO users (email, password)
VALUES (?, ?);
`
// CreateUser creates a new user in the database.
func (r *Repository) CreateUser(user *CreateUserParams) error {
// Use a prepared statement
stmt, err := r.db.Prepare(createUserSQL)
if err != nil {
return fmt.Errorf("failed to prepare statement: %w", err)
}
defer stmt.Close()
_, err = stmt.Exec(user.Email, user.Password)
if err != nil {
return fmt.Errorf("failed to execute statement: %w", err)
}
return nil
}
const getUserByEmailSQL = `
SELECT
id,
email,
password,
created_at
FROM users
WHERE email=?;
`
// GetUserByEmail returns a user from the database by email.
func (r *Repository) GetUserByEmail(email string) (*User, error) {
var user User
err := r.db.QueryRow(getUserByEmailSQL, email).Scan(
&user.ID,
&user.Email,
&user.Password,
&user.CreatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to query database: %w", err)
}
return &user, nil
}
In order to use the new Repository type, it needs to be accessible to the RegisterUser function. As a result, we also need to create a new Service type which encapsulates the Repository and provides RegisterUser as a method.
service.go:
type Service struct {
repo *Repository
}
func NewService(repo *Repository) *Service {
return &Service{
repo: repo,
}
}
Next, update RegisterUser to use the new CreateUser and GetUserByEmail methods on the Repository type via the Service type.
service.go:
func (s *Service) RegisterUser(email string, password string) (*User, error) {
// ...
userToCreate := &CreateUserParams{
Email: email,
Password: string(hashedPassword),
}
err = s.repo.CreateUser(userToCreate)
if err != nil {
log.Printf("Failed to create user: %v", err)
return nil, err
}
userFromDB, err := s.repo.GetUserByEmail(email)
if err != nil {
log.Printf("Failed to get user: %v", err)
return nil, err
}
return userFromDB, nil
}
Notice that the additional layer of abstraction enables us to log higher-level errors. Instead of redundantly logging errors in every layer, we wrap the persistence layer errors with context and log them in the service layer. Although we could wrap the service layer errors too and log them in the presentation layer, it's possible that service layer functions could be called without an HTTP request, like from a background job or CLI tool.
You may have noticed by now that we don't actually need GetUserByEmail in our specific use case. We could simply make CreateUser return the newly created user instead. However, I chose to include it for demonstration purposes.
Next, we need to update RegisterUserHandler to have access to Service, since RegisterUser is now a method on Service. To do this, we need to turn it from an http.HandlerFunc into a regular function which returns an http.HandlerFunc.
handler.go:
func RegisterUserHandler(svc *Service) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// ...
user, err := svc.RegisterUser(req.Email, req.Password)
// ...
}
}
We also need to update CreateRouter to pass the Service to RegisterUserHandler.
router.go:
// CreateRouter creates a new router and registers all routes.
func CreateRouter(svc *Service) http.Handler {
r := chi.NewRouter()
r.Route("/api", func(r chi.Router) {
r.Route("/users", func(r chi.Router) {
r.Post("/register", RegisterUserHandler(svc))
})
})
return r
}
Finally, we need to update Run to create the Repository and Service and pass them to CreateRouter.
main.go:
func Run() {
repo, err := NewRepository(config.DBPath)
if err != nil {
log.Fatalf("Failed to create repository: %v", err)
}
defer repo.Close()
svc := NewService(repo)
router := CreateRouter(svc)
fmt.Println("Server listening on", config.ServerAddress)
log.Fatal(http.ListenAndServe(config.ServerAddress, router))
}
We have now successfully defined all the layers in our application! However, we're not done yet. We're still limited in our ability to test our new layers.
This is because we've defined our Repository and Service types as concrete types. As a result, we can't mock them out in our tests. In simple language, this means that we can't write tests which replace the real Repository and Service with fake ones that we control. This leaves us with writing end-to-end tests, which is no better than where we started. (At least, not from a testing perspective.)
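To solve this, we need to turn our concrete types into interfaces. We can then write other concrete types that satisfy these interfaces, such as mocks for our tests, and pass them in wherever the real types were used. Supplying a component's dependencies from the outside like this, rather than having it construct them itself, is called dependency injection; the interfaces are what make those dependencies swappable.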
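Go interfaces are satisfied implicitly, which is what keeps this kind of injection lightweight. Here's a minimal sketch, independent of the service code, with a hypothetical UserStore interface and Checker type. It shows constructor injection and the common var _ Interface = (*Type)(nil) idiom, which makes the compiler confirm that a type satisfies an interface.

```go
package main

import "fmt"

// UserStore is a hypothetical, cut-down stand-in for the Repository interface.
type UserStore interface {
	EmailExists(email string) (bool, error)
}

// fakeStore satisfies UserStore implicitly: Go has no "implements" keyword,
// so having the right method set is enough.
type fakeStore struct {
	emails map[string]bool
}

func (f *fakeStore) EmailExists(email string) (bool, error) {
	return f.emails[email], nil
}

// Compile-time check that *fakeStore satisfies UserStore.
var _ UserStore = (*fakeStore)(nil)

// Checker depends only on the interface, never on a concrete store.
type Checker struct {
	store UserStore
}

// NewChecker receives its dependency from the caller: this is the injection.
func NewChecker(store UserStore) *Checker {
	return &Checker{store: store}
}

// CanRegister reports whether email is still available.
func (c *Checker) CanRegister(email string) (bool, error) {
	exists, err := c.store.EmailExists(email)
	if err != nil {
		return false, err
	}
	return !exists, nil
}

func main() {
	// A test injects the fake; production code would inject a real store.
	c := NewChecker(&fakeStore{emails: map[string]bool{"taken@example.com": true}})
	for _, email := range []string{"taken@example.com", "new@example.com"} {
		ok, err := c.CanRegister(email)
		fmt.Println(email, ok, err)
	}
}
```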
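Let's start with the Repository type. Remember to update NewRepository to return the new interface type. You will also have to make other small type changes, such as changing the receivers of CreateUser and GetUserByEmail from *Repository to *SQLiteRepository.
repository.go: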
// Repository represents an interface to the database.
type Repository interface {
CreateUser(user *CreateUserParams) error
GetUserByEmail(email string) (*User, error)
Close() error
}
// SQLiteRepository is a repository for SQLite databases.
type SQLiteRepository struct {
db *sql.DB
}
func NewRepository(databasePath string) (Repository, error) {
// ...
repo := &SQLiteRepository{
db: db,
}
return repo, nil
}
func (r *SQLiteRepository) Close() error {
return r.db.Close()
}
Next, let's do the same for the Service type.
service.go:
type Service interface {
RegisterUser(email string, password string) (*User, error)
}
type UserService struct {
repo Repository
}
func NewService(repo Repository) Service {
return &UserService{
repo: repo,
}
}
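Now, SQLiteRepository and UserService are concrete types which implement the Repository and Service interfaces respectively. As with the repository, change the receiver of RegisterUser to *UserService. RegisterUserHandler and CreateRouter should also take a Service rather than a *Service: a pointer to an interface is a different type that *UserService doesn't satisfy, and it would stop us from passing in a mock later.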
You might notice that we chose a technology-centric name for the concrete Repository type, yet a domain-centric name for the concrete Service type. This stems from the different dimensions of abstraction for each type. If we wanted to introduce support for another type of database or domain, we might create PostgresRepository or OrderService respectively.
We can finally write some tests! I recommend using a mocking package for real applications, but I will be writing manual mocks to keep the code agnostic. My mocks will be naive, so please don't mock me. (Sorry, I couldn't resist.)
To test our presentation layer, we need to write unit tests which use a mock service. This is because we want to test our networking logic, not our domain logic.
Create handler_test.go with a MockService type which implements the Service interface.
handler_test.go:
type MockService struct{}
func (m *MockService) RegisterUser(email, password string) (*User, error) {
user := &User{
ID: 1,
Email: email,
Password: password,
CreatedAt: time.Now(),
}
return user, nil
}
Next, create a TestRegisterUserHandler function which tests the RegisterUserHandler function.
handler_test.go:
func TestRegisterUserHandler(t *testing.T) {
mockSvc := &MockService{}
handler := RegisterUserHandler(mockSvc)
body := map[string]string{
"email": "test@example.com",
"password": "securepass123",
}
bodyBytes, _ := json.Marshal(body)
req, err := http.NewRequest("POST", "/api/users/register", bytes.NewReader(bodyBytes))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
if status := recorder.Code; status != http.StatusOK {
t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusOK)
}
expected := map[string]interface{}{
"id": float64(1), // JSON numbers are parsed as floats
"email": "test@example.com",
}
var actual map[string]interface{}
err = json.NewDecoder(recorder.Body).Decode(&actual)
if err != nil {
t.Fatalf("Failed to decode response: %v", err)
}
if actual["id"] != expected["id"] {
t.Errorf("Handler returned unexpected ID: got %v want %v", actual["id"], expected["id"])
}
if actual["email"] != expected["email"] {
t.Errorf("Handler returned unexpected email: got %v want %v", actual["email"], expected["email"])
}
}
Let's also test all the bad request cases. This includes invalid request bodies, and also that we're using the email validation package correctly.
handler_test.go:
func TestRegisterUserHandler_InvalidInput(t *testing.T) {
testCases := []struct {
name string
body map[string]string
}{
{
name: "invalid body",
body: map[string]string{
"test": "test",
},
},
{
name: "missing email",
body: map[string]string{
"password": "securepass123",
},
},
{
name: "missing password",
body: map[string]string{
"email": "test@example.com",
},
},
{
name: "invalid email",
body: map[string]string{
"email": "test",
"password": "securepass123",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockSvc := &MockService{}
handler := RegisterUserHandler(mockSvc)
bodyBytes, _ := json.Marshal(tc.body)
req, err := http.NewRequest("POST", "/api/users/register", bytes.NewReader(bodyBytes))
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
status := recorder.Code
if status != http.StatusBadRequest {
t.Errorf("Handler returned wrong status code: got %v want %v", status, http.StatusBadRequest)
}
})
}
}
To test our service layer, we need to write unit tests which use a mock repository. This is because we want to test our domain logic, not our database logic.
Create service_test.go with a MockRepository type which implements the Repository interface.
service_test.go:
type MockRepository struct{}
func (m *MockRepository) CreateUser(user *CreateUserParams) error {
return nil
}
func (m *MockRepository) GetUserByEmail(email string) (*User, error) {
mockUser := User{
ID: 1,
Email: email,
Password: "securepass123",
CreatedAt: time.Now(),
}
return &mockUser, nil
}
func (m *MockRepository) Close() error {
return nil
}
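A common Go idiom for hand-written mocks like this is a compile-time assertion that the mock still satisfies its interface. The sketch below assumes the Repository interface has exactly these three methods; the User and CreateUserParams definitions are likewise assumed shapes based on the fields used in this post.

```go
package main

import (
	"fmt"
	"time"
)

// User and CreateUserParams are assumed shapes matching the fields used above.
type User struct {
	ID        int
	Email     string
	Password  string
	CreatedAt time.Time
}

type CreateUserParams struct {
	Email    string
	Password string
}

// Repository is assumed to contain exactly the methods the mock implements.
type Repository interface {
	CreateUser(user *CreateUserParams) error
	GetUserByEmail(email string) (*User, error)
	Close() error
}

type MockRepository struct{}

func (m *MockRepository) CreateUser(user *CreateUserParams) error { return nil }

func (m *MockRepository) GetUserByEmail(email string) (*User, error) {
	return &User{ID: 1, Email: email, Password: "securepass123", CreatedAt: time.Now()}, nil
}

func (m *MockRepository) Close() error { return nil }

// Compile-time check: the build fails if MockRepository drifts from Repository,
// for example when a method is added to the interface but not to the mock.
var _ Repository = (*MockRepository)(nil)

func main() {
	var repo Repository = &MockRepository{}
	user, _ := repo.GetUserByEmail("test@example.com")
	fmt.Println(user.ID, user.Email) // 1 test@example.com
}
```

If a method is later added to Repository, the `var _ Repository` line stops compiling and points straight at the stale mock, rather than surfacing as a confusing error inside a test.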
Next, create a TestRegisterUser function which tests the RegisterUser method.
service_test.go:
func TestRegisterUser(t *testing.T) {
mockRepo := &MockRepository{}
svc := NewService(mockRepo)
email := "test@example.com"
password := "securepass123"
user, err := svc.RegisterUser(email, password)
if err != nil {
t.Fatalf("Failed to register user: %v", err)
}
if user.Email != email {
t.Errorf("Service returned unexpected email: got %v want %v", user.Email, email)
}
}
Let's also test our validation logic. This is where ValidationError.Field comes in handy, since it tells us exactly which input failed validation.
service_test.go:
func TestRegisterUser_InvalidInput(t *testing.T) {
testCases := []struct {
name string
field string
email string
password string
}{
{
name: "short password",
field: "password",
email: "test@example.com",
password: "short",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
mockRepo := &MockRepository{}
svc := NewService(mockRepo)
_, err := svc.RegisterUser(tc.email, tc.password)
var validationErr *ValidationError
if errors.As(err, &validationErr) {
if validationErr.Field != tc.field {
t.Errorf("Service returned unexpected validation error field: got %v want %v", validationErr.Field, tc.field)
}
} else {
t.Errorf("Service did not return validation error")
}
})
}
}
You might have noticed that TestRegisterUser does not test that the password is hashed. We can assume that the hashing package works as intended, but we should still test that we're using it correctly. The problem is that our mock GetUserByEmail returns a hard-coded password, rather than the password that was actually passed to the mock CreateUser. The hard-coded password is unhashed, whereas RegisterUser presumably hashed the test password before calling CreateUser. Since the hard-coded value is the same as our test password, a check that the returned password differs from the plaintext would fail even if the hashing worked.
To solve this, we need to update MockRepository to actually save the user in CreateUser and to retrieve the previously saved user in GetUserByEmail. This mimics the behaviour of a real database. We'll do this by storing the users in a map, where the key is the user's email address.
service_test.go:
type MockRepository struct {
Users map[string]*User
}
func (m *MockRepository) CreateUser(user *CreateUserParams) error {
mockUser := &User{
ID: 1,
Email: user.Email,
Password: user.Password,
CreatedAt: time.Now(),
}
if m.Users == nil {
m.Users = make(map[string]*User)
}
m.Users[mockUser.Email] = mockUser
return nil
}
func (m *MockRepository) GetUserByEmail(email string) (*User, error) {
mockUser, exists := m.Users[email]
if !exists {
return nil, errors.New("user not found")
}
return mockUser, nil
}
Now we can update TestRegisterUser to check that the password is being hashed.
service_test.go:
// ...
if user.Password == password {
t.Errorf("Service returned unhashed password")
}
}
Many mocking libraries allow you to set "expectations" on mocked methods. For example, if we were using gomock, we could set an expectation that CreateUser is called exactly once, with a password that differs from the original plaintext password.
mockRepo.EXPECT().CreateUser(gomock.Any()).
Do(func(user *CreateUserParams) {
if user.Password == password {
t.Error("Password was not hashed before creating user")
}
}).Times(1)
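Without a mocking library, a hand-written spy can express the same expectation: it records each CreateUser call, and the test asserts on the recorded calls afterwards. In this sketch, registerUser is a hypothetical stand-in for the service's RegisterUser with the hash function injected, userCreator is an assumed one-method slice of the Repository interface, and the fake hash is a test double, not a real password hash.

```go
package main

import (
	"errors"
	"fmt"
)

type CreateUserParams struct {
	Email    string
	Password string
}

// userCreator is the single repository method this example needs.
type userCreator interface {
	CreateUser(params *CreateUserParams) error
}

// SpyRepository records every CreateUser call instead of storing users.
type SpyRepository struct {
	Calls []CreateUserParams
}

func (s *SpyRepository) CreateUser(params *CreateUserParams) error {
	s.Calls = append(s.Calls, *params)
	return nil
}

// registerUser is a hypothetical stand-in for the service's RegisterUser.
// The hash function is injected so this sketch needs no crypto dependency.
func registerUser(repo userCreator, hash func(string) (string, error), email, password string) error {
	hashed, err := hash(password)
	if err != nil {
		return fmt.Errorf("hash password: %w", err)
	}
	return repo.CreateUser(&CreateUserParams{Email: email, Password: hashed})
}

// expectHashedOnce plays the role of the gomock expectation: exactly one
// CreateUser call, whose password differs from the plaintext.
func expectHashedOnce(spy *SpyRepository, plaintext string) error {
	if len(spy.Calls) != 1 {
		return fmt.Errorf("CreateUser called %d times, want 1", len(spy.Calls))
	}
	if spy.Calls[0].Password == plaintext {
		return errors.New("password was not hashed before creating user")
	}
	return nil
}

func main() {
	// fakeHash is a deterministic test double, not a real password hash.
	fakeHash := func(p string) (string, error) { return "hashed:" + p, nil }
	spy := &SpyRepository{}
	if err := registerUser(spy, fakeHash, "test@example.com", "securepass123"); err != nil {
		fmt.Println("register failed:", err)
		return
	}
	fmt.Println(expectHashedOnce(spy, "securepass123")) // <nil>
}
```

The trade-off is that a spy checks its expectations only after the call, whereas gomock can fail the moment an unexpected call happens.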
To test our persistence layer, we need to write integration tests that use a real database. This is because we want to test that our queries are correct and that they work with the database. Mocking the database here would defeat the purpose, since the queries themselves are what we're testing.
Create repository_test.go with a NewTestRepository function which creates a new SQLiteRepository using a temporary database.
repository_test.go:
func NewTestRepository() (Repository, error) {
db.ResetDB(config.TestDBPath)
return NewRepository(config.TestDBPath)
}
Next, create a TestUsersRepository function which tests the CreateUser and GetUserByEmail methods to ensure that users are created and retrieved correctly, that duplicate emails are rejected, and that each user gets a unique ID.
repository_test.go:
func TestUsersRepository(t *testing.T) {
testCases := []struct {
name string
email string
}{
{
name: "first user",
email: "test1@example.com",
},
{
name: "second user",
email: "test2@example.com",
},
}
repo, err := NewTestRepository()
if err != nil {
t.Fatalf("Failed to create repository: %v", err)
}
// Defers are executed in LIFO order
defer os.Remove(config.TestDBPath)
defer repo.Close()
// Keep track of unique IDs returned by the repository
uniqueIDs := make(map[int]bool)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
userToCreate := &CreateUserParams{
Email: tc.email,
Password: "securepass123",
}
err := repo.CreateUser(userToCreate) // subtest-local err, not the outer test's variable
if err != nil {
t.Fatalf("Failed to create user: %v", err)
}
err = repo.CreateUser(userToCreate)
if err == nil {
t.Fatalf("Repository created user with duplicate email")
}
userFromDB, err := repo.GetUserByEmail(userToCreate.Email)
if err != nil {
t.Fatalf("Failed to get user: %v", err)
}
if userFromDB.Email != userToCreate.Email {
t.Fatalf("Repository returned unexpected email: got %v want %v", userFromDB.Email, userToCreate.Email)
}
uniqueIDs[userFromDB.ID] = true
})
}
if len(uniqueIDs) != len(testCases) {
t.Fatalf("Repository returned duplicate user IDs")
}
}
Now we can run our tests with go test -v and rejoice at them all passing!
=== RUN TestRegisterUserHandler
--- PASS: TestRegisterUserHandler (0.00s)
=== RUN TestRegisterUserHandler_InvalidInput
=== RUN TestRegisterUserHandler_InvalidInput/invalid_body
=== RUN TestRegisterUserHandler_InvalidInput/missing_email
=== RUN TestRegisterUserHandler_InvalidInput/missing_password
=== RUN TestRegisterUserHandler_InvalidInput/invalid_email
--- PASS: TestRegisterUserHandler_InvalidInput (0.00s)
--- PASS: TestRegisterUserHandler_InvalidInput/invalid_body (0.00s)
--- PASS: TestRegisterUserHandler_InvalidInput/missing_email (0.00s)
--- PASS: TestRegisterUserHandler_InvalidInput/missing_password (0.00s)
--- PASS: TestRegisterUserHandler_InvalidInput/invalid_email (0.00s)
=== RUN TestUsersRepository
=== RUN TestUsersRepository/first_user
=== RUN TestUsersRepository/second_user
--- PASS: TestUsersRepository (0.00s)
--- PASS: TestUsersRepository/first_user (0.00s)
--- PASS: TestUsersRepository/second_user (0.00s)
=== RUN TestRegisterUser
--- PASS: TestRegisterUser (0.06s)
=== RUN TestRegisterUser_InvalidInput
=== RUN TestRegisterUser_InvalidInput/short_password
--- PASS: TestRegisterUser_InvalidInput (0.00s)
--- PASS: TestRegisterUser_InvalidInput/short_password (0.00s)
PASS
ok layer/versions/v4 0.168s
Congratulations, you've made it to the end! In this post, we learned what a layered architecture is and how it can improve the maintainability and testability of our code. We also learned how to refactor an existing application into a layered architecture, and how to write unit and integration tests. I hope you've come away with ideas on how to improve your own applications, and with increased confidence in your ability to refactor code.