package apigateway
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
v2 "github.com/apache/apisix-ingress-controller/pkg/kube/apisix/apis/config/v2"
"github.com/go-chi/chi"
ctxutil "github.com/goodrain/rainbond/api/util/ctx"
"github.com/goodrain/rainbond/db"
dbdao "github.com/goodrain/rainbond/db/dao"
dbmodel "github.com/goodrain/rainbond/db/model"
"github.com/goodrain/rainbond/pkg/component/k8s"
"github.com/jinzhu/gorm"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/runtime/serializer"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
)
// tcpRouteTestManager is a db.Manager test double. It embeds db.Manager so it
// satisfies the full interface, and overrides only the TenantServiceDao and
// TCPRuleDao accessors with the stubs stored in its fields.
type tcpRouteTestManager struct {
	// Manager supplies every method that is not overridden. The tests in this
	// file leave it nil, so calling a non-overridden method would panic.
	db.Manager

	// tenantServiceDao is handed out by TenantServiceDao.
	tenantServiceDao dbdao.TenantServiceDao
	// tcpRuleDao is handed out by TCPRuleDao.
	tcpRuleDao dbdao.TCPRuleDao
}
// TenantServiceDao returns the tenant-service DAO stub configured on the
// manager, shadowing the embedded db.Manager implementation.
func (m tcpRouteTestManager) TenantServiceDao() dbdao.TenantServiceDao {
	stub := m.tenantServiceDao
	return stub
}
// TCPRuleDao returns the TCP-rule DAO stub configured on the manager,
// shadowing the embedded db.Manager implementation.
func (m tcpRouteTestManager) TCPRuleDao() dbdao.TCPRuleDao {
	stub := m.tcpRuleDao
	return stub
}
// tcpRouteTenantServiceDao is an in-memory dbdao.TenantServiceDao stub. Only
// the lookups defined on it in this file are implemented; everything else
// falls through to the embedded interface value.
type tcpRouteTenantServiceDao struct {
	dbdao.TenantServiceDao

	// servicesByID maps a service ID to the record GetServiceByID returns.
	servicesByID map[string]*dbmodel.TenantServices
}
// GetServiceByTenantIDAndServiceAlias always reports gorm.ErrRecordNotFound,
// regardless of its arguments, so lookups by tenant and alias never succeed
// against this stub.
func (d *tcpRouteTenantServiceDao) GetServiceByTenantIDAndServiceAlias(tenantID, serviceName string) (*dbmodel.TenantServices, error) {
	var none *dbmodel.TenantServices
	return none, gorm.ErrRecordNotFound
}
// GetServiceByID returns the record registered under serviceID in
// servicesByID, or gorm.ErrRecordNotFound when no such entry exists.
func (d *tcpRouteTenantServiceDao) GetServiceByID(serviceID string) (*dbmodel.TenantServices, error) {
	if found, exists := d.servicesByID[serviceID]; exists {
		return found, nil
	}
	return nil, gorm.ErrRecordNotFound
}
// tcpRouteRuleDao is a dbdao.TCPRuleDao stub that records the rule passed to
// AddModel so tests can inspect what would have been persisted.
type tcpRouteRuleDao struct {
	dbdao.TCPRuleDao

	// added holds the most recent rule given to AddModel; nil until then.
	added *dbmodel.TCPRule
}
// AddModel captures m in d.added instead of writing to a database and always
// returns nil. The type assertion is unchecked, so passing anything other
// than a *dbmodel.TCPRule panics.
func (d *tcpRouteRuleDao) AddModel(m dbmodel.Interface) error {
	rule := m.(*dbmodel.TCPRule)
	d.added = rule
	return nil
}
// newTCPRouteTestClientset returns a Kubernetes clientset wired to an
// in-process httptest server that emulates a small part of the core/v1
// Service API for the "default" namespace, plus a function that shuts the
// server down.
//
// The services map is the server's backing store and is mutated in place:
//   - GET    /api/v1/namespaces/default/services         lists every entry
//   - GET    /api/v1/namespaces/default/services/{name}  returns one entry or 404
//   - POST   /api/v1/namespaces/default/services         stores the decoded service
//   - DELETE /api/v1/namespaces/default/services/{name}  removes the entry or 404
//
// Any other request is reported as a test error and answered with a 500
// Status object.
//
// Handler code runs on the server's goroutines, so it reports problems with
// t.Errorf and an HTTP error response; t.Fatalf/FailNow may only be called
// from the goroutine running the test.
func newTCPRouteTestClientset(t *testing.T, services map[string]*corev1.Service) (*kubernetes.Clientset, func()) {
	t.Helper()

	const (
		servicesPath  = "/api/v1/namespaces/default/services"
		servicePrefix = servicesPath + "/"
	)

	scheme := runtime.NewScheme()
	if err := corev1.AddToScheme(scheme); err != nil {
		t.Fatalf("add core scheme: %v", err)
	}
	codecs := serializer.NewCodecFactory(scheme)

	// writeJSON sends obj as the JSON response body with the given status
	// code; what names the payload in the error reported on encode failure.
	writeJSON := func(w http.ResponseWriter, code int, what string, obj interface{}) {
		w.WriteHeader(code)
		if err := json.NewEncoder(w).Encode(obj); err != nil {
			t.Errorf("encode %s: %v", what, err)
		}
	}

	// writeStatus sends a metav1.Status object whose Code mirrors the HTTP
	// status code of the response.
	writeStatus := func(w http.ResponseWriter, code int, status string, reason v1.StatusReason, message string) {
		writeJSON(w, code, "status", v1.Status{
			TypeMeta: v1.TypeMeta{
				Kind:       "Status",
				APIVersion: "v1",
			},
			Status:  status,
			Reason:  reason,
			Message: message,
			Code:    int32(code),
		})
	}

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		path := r.URL.Path

		switch {
		case r.Method == http.MethodGet && path == servicesPath:
			serviceList := corev1.ServiceList{}
			for _, service := range services {
				serviceList.Items = append(serviceList.Items, *service)
			}
			writeJSON(w, http.StatusOK, "service list", &serviceList)

		case r.Method == http.MethodGet && strings.HasPrefix(path, servicePrefix):
			name := strings.TrimPrefix(path, servicePrefix)
			service, ok := services[name]
			if !ok {
				writeStatus(w, http.StatusNotFound, v1.StatusFailure, v1.StatusReasonNotFound, "")
				return
			}
			writeJSON(w, http.StatusOK, "service", service)

		case r.Method == http.MethodPost && path == servicesPath:
			var service corev1.Service
			if err := json.NewDecoder(r.Body).Decode(&service); err != nil {
				t.Errorf("decode service: %v", err)
				writeStatus(w, http.StatusBadRequest, v1.StatusFailure, v1.StatusReasonBadRequest, err.Error())
				return
			}
			// The request path fixes the namespace regardless of the payload.
			service.Namespace = "default"
			services[service.Name] = &service
			writeJSON(w, http.StatusCreated, "created service", &service)

		case r.Method == http.MethodDelete && strings.HasPrefix(path, servicePrefix):
			name := strings.TrimPrefix(path, servicePrefix)
			if _, ok := services[name]; !ok {
				writeStatus(w, http.StatusNotFound, v1.StatusFailure, v1.StatusReasonNotFound, "")
				return
			}
			delete(services, name)
			writeStatus(w, http.StatusOK, v1.StatusSuccess, "", "")

		default:
			t.Errorf("unexpected kubernetes request: %s %s", r.Method, path)
			writeStatus(w, http.StatusInternalServerError, v1.StatusFailure, v1.StatusReasonInternalError, "unexpected kubernetes request")
		}
	}))

	config := &rest.Config{
		Host: server.URL,
		ContentConfig: rest.ContentConfig{
			GroupVersion:         &schema.GroupVersion{Version: "v1"},
			NegotiatedSerializer: codecs.WithoutConversion(),
		},
	}
	clientset, err := kubernetes.NewForConfig(config)
	if err != nil {
		// Fatalf stops the test before the caller can defer the close
		// function, so release the server here.
		server.Close()
		t.Fatalf("create clientset: %v", err)
	}
	return clientset, server.Close
}
// TestParseCertManagerDomains checks that a comma-separated domain string
// with stray spaces and an empty element is parsed into exactly the three
// trimmed domains, in input order.
func TestParseCertManagerDomains(t *testing.T) {
	want := [...]string{"foo.example.com", "bar.example.com", "baz.example.com"}

	got := parseCertManagerDomains("foo.example.com, bar.example.com ,,baz.example.com")
	if len(got) != 3 {
		t.Fatalf("expected 3 domains, got %d", len(got))
	}
	for i, domain := range want {
		if got[i] != domain {
			t.Fatalf("unexpected domains: %#v", got)
		}
	}
}
// TestHasMatchingCertManagerDomain is a table-driven test of
// hasMatchingCertManagerDomain. Each case supplies the domains covered by a
// certificate and the domains used by a route, and states whether the two
// sets are expected to match.
func TestHasMatchingCertManagerDomain(t *testing.T) {
	tests := []struct {
		name         string
		certDomains  []string
		routeDomains []string
		want         bool
	}{
		{
			name:         "exact match",
			certDomains:  []string{"foo.example.com"},
			routeDomains: []string{"foo.example.com"},
			want:         true,
		},
		{
			// A wildcard certificate domain is expected to cover a route host
			// one label below it. The expectation was previously false, which
			// contradicted the case name and made it indistinguishable from
			// "no overlap".
			name:         "wildcard match",
			certDomains:  []string{"*.example.com"},
			routeDomains: []string{"foo.example.com"},
			want:         true,
		},
		{
			name:         "no overlap",
			certDomains:  []string{"foo.example.com"},
			routeDomains: []string{"bar.example.com"},
			want:         false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := hasMatchingCertManagerDomain(tt.certDomains, tt.routeDomains)
			if got != tt.want {
				t.Fatalf("expected %v, got %v", tt.want, got)
			}
		})
	}
}
// TestRouteMatchesCertManagerDomains builds a route with one HTTP rule
// serving two hosts and checks that routeMatchesCertManagerDomains reports a
// match for a certificate covering one of those hosts and no match for an
// unrelated domain.
func TestRouteMatchesCertManagerDomains(t *testing.T) {
	httpRule := v2.ApisixRouteHTTP{
		Match: v2.ApisixRouteHTTPMatch{
			Hosts: []string{"foo.example.com", "bar.example.com"},
		},
	}
	route := &v2.ApisixRoute{
		Spec: v2.ApisixRouteSpec{
			HTTP: []v2.ApisixRouteHTTP{httpRule},
		},
	}

	matched := routeMatchesCertManagerDomains(route, []string{"bar.example.com"})
	if !matched {
		t.Fatal("expected route to match certificate domains")
	}

	unrelated := routeMatchesCertManagerDomains(route, []string{"baz.example.com"})
	if unrelated {
		t.Fatal("expected route not to match unrelated certificate domains")
	}
}
// TestCreateTCPRouteUsesRainbondServiceAliasFromBackendServiceLabels posts a
// TCP stream route whose backend is a Kubernetes service named "demo-2048"
// that carries Rainbond labels (service_id, service_alias, app_id). It then
// checks that a "<serviceName>-30000" service was created whose selector and
// labels use the Rainbond service alias and ID from those labels, and that a
// TCP rule for the same service ID was handed to the rule DAO.
func TestCreateTCPRouteUsesRainbondServiceAliasFromBackendServiceLabels(t *testing.T) {
	const (
		namespace    = "default"
		tenantID     = "tenant-id"
		appID        = "app-id"
		serviceID    = "db66afd0892c326ff557df7880ac572d"
		serviceAlias = "grac572d"
		serviceName  = "demo-2048"
		nodePort     = int32(30000)
	)

	// Fake cluster state: only the backend service exists up front.
	backendLabels := map[string]string{
		"app_id":        appID,
		"service_id":    serviceID,
		"service_alias": serviceAlias,
		"rainbond_app":  serviceName,
	}
	backend := &corev1.Service{
		ObjectMeta: v1.ObjectMeta{
			Name:      serviceName,
			Namespace: namespace,
			Labels:    backendLabels,
		},
		Spec: corev1.ServiceSpec{
			Ports: []corev1.ServicePort{{
				Name:       "http-8080",
				Port:       8080,
				TargetPort: intstr.FromInt(8080),
			}},
			Selector: map[string]string{"name": serviceAlias},
		},
	}
	clientset, closeServer := newTCPRouteTestClientset(t, map[string]*corev1.Service{serviceName: backend})
	defer closeServer()
	k8s.New().Clientset = clientset

	// Fake database: one tenant service plus a recording TCP rule DAO.
	component := &dbmodel.TenantServices{
		ServiceID:        serviceID,
		ServiceAlias:     serviceAlias,
		TenantID:         tenantID,
		ExtendMethod:     "",
		K8sComponentName: serviceAlias,
	}
	ruleDao := &tcpRouteRuleDao{}
	db.SetTestManager(tcpRouteTestManager{
		tenantServiceDao: &tcpRouteTenantServiceDao{
			servicesByID: map[string]*dbmodel.TenantServices{serviceID: component},
		},
		tcpRuleDao: ruleDao,
	})
	defer db.SetTestManager(nil)

	// Request: a TCP stream route on nodePort pointing at the backend.
	payload, err := json.Marshal(v2.ApisixRouteStream{
		Name:     "tcp",
		Protocol: "tcp",
		Match: v2.ApisixRouteStreamMatch{
			IngressPort: nodePort,
		},
		Backend: v2.ApisixRouteStreamBackend{
			ServiceName: serviceName,
			ServicePort: intstr.FromInt(8080),
		},
	})
	if err != nil {
		t.Fatalf("marshal route: %v", err)
	}
	tenant := &dbmodel.Tenants{
		UUID:      tenantID,
		Namespace: namespace,
	}
	req := httptest.NewRequest(http.MethodPost, "/?appID="+appID, bytes.NewReader(payload))
	req = req.WithContext(context.WithValue(req.Context(), ctxutil.ContextKey("tenant"), tenant))
	rr := httptest.NewRecorder()

	Struct{}.CreateTCPRoute(rr, req)

	// The handler is expected to have created the NodePort service.
	serviceClient := k8s.Default().Clientset.CoreV1().Services(namespace)
	created, err := serviceClient.Get(context.Background(), serviceName+"-30000", v1.GetOptions{})
	switch {
	case errors.IsNotFound(err):
		t.Fatalf("expected NodePort service to be created")
	case err != nil:
		t.Fatalf("get created service: %v", err)
	}

	if got := created.Spec.Selector["service_alias"]; got != serviceAlias {
		t.Fatalf("expected selector service_alias %q, got %q", serviceAlias, got)
	}
	if got := created.Labels["service_alias"]; got != serviceAlias {
		t.Fatalf("expected label service_alias %q, got %q", serviceAlias, got)
	}
	if got := created.Labels["service_id"]; got != serviceID {
		t.Fatalf("expected label service_id %q, got %q", serviceID, got)
	}

	// The rule DAO stub records what would have been persisted.
	persisted := ruleDao.added
	if persisted == nil {
		t.Fatal("expected TCP rule to be persisted")
	}
	if got := persisted.ServiceID; got != serviceID {
		t.Fatalf("expected TCP rule service_id %q, got %q", serviceID, got)
	}
}
// TestGetTCPRouteIncludesServiceMetadata seeds the fake cluster with one
// NodePort service labelled as a TCP route and checks that GetTCPRoute
// answers 200 with a single list entry. The entry must expose the port
// fields and the service_name, service_alias, service_id, app_id and
// container_port metadata.
func TestGetTCPRouteIncludesServiceMetadata(t *testing.T) {
	const (
		namespace    = "default"
		appID        = "4d0f77e042f94ae2a77552fe7b595faf"
		serviceID    = "7de1e7b94ccf418eac0cc0de61447979"
		serviceAlias = "gr447979"
		serviceName  = "gr447979-30003"
		nodePort     = int32(30003)
	)

	routePort := corev1.ServicePort{
		Name:       serviceName,
		Protocol:   corev1.ProtocolTCP,
		Port:       80,
		TargetPort: intstr.FromInt(80),
		NodePort:   nodePort,
	}
	routeService := &corev1.Service{
		ObjectMeta: v1.ObjectMeta{
			Name:      serviceName,
			Namespace: namespace,
			Labels: map[string]string{
				"app_id":        appID,
				"service_id":    serviceID,
				"service_alias": serviceAlias,
				"port":          "80",
				"tcp":           "true",
			},
		},
		Spec: corev1.ServiceSpec{
			Ports:    []corev1.ServicePort{routePort},
			Selector: map[string]string{"service_alias": serviceAlias},
			Type:     corev1.ServiceTypeNodePort,
		},
	}
	clientset, closeServer := newTCPRouteTestClientset(t, map[string]*corev1.Service{serviceName: routeService})
	defer closeServer()
	k8s.New().Clientset = clientset

	tenant := &dbmodel.Tenants{
		Namespace: namespace,
	}
	req := httptest.NewRequest(http.MethodGet, "/?appID="+appID, nil)
	req = req.WithContext(context.WithValue(req.Context(), ctxutil.ContextKey("tenant"), tenant))
	rr := httptest.NewRecorder()

	Struct{}.GetTCPRoute(rr, req)

	if status := rr.Code; status != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", status, rr.Body.String())
	}

	// Only the fields asserted below are decoded from the response.
	var resp struct {
		List []struct {
			Name          string `json:"name"`
			Port          int32  `json:"port"`
			NodePort      int32  `json:"nodePort"`
			ServiceName   string `json:"service_name"`
			ServiceAlias  string `json:"service_alias"`
			ServiceID     string `json:"service_id"`
			AppID         string `json:"app_id"`
			ContainerPort int32  `json:"container_port"`
		} `json:"list"`
	}
	if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
		t.Fatalf("decode response: %v", err)
	}
	if len(resp.List) != 1 {
		t.Fatalf("expected one route, got %#v", resp.List)
	}

	got := resp.List[0]
	portFieldsOK := got.Name == serviceName && got.Port == 80 && got.NodePort == nodePort
	if !portFieldsOK {
		t.Fatalf("unexpected service port fields: %#v", got)
	}
	if got.ServiceName != serviceName {
		t.Fatalf("expected service_name %q, got %q", serviceName, got.ServiceName)
	}
	if got.ServiceAlias != serviceAlias {
		t.Fatalf("expected service_alias %q, got %q", serviceAlias, got.ServiceAlias)
	}
	if got.ServiceID != serviceID {
		t.Fatalf("expected service_id %q, got %q", serviceID, got.ServiceID)
	}
	if got.AppID != appID {
		t.Fatalf("expected app_id %q, got %q", appID, got.AppID)
	}
	if got.ContainerPort != 80 {
		t.Fatalf("expected container_port 80, got %d", got.ContainerPort)
	}
}
// TestDeleteTCPRouteFallsBackToNodePortWhenNameDiffers issues a delete for
// "custom-k8s-service-30003" while the only service in the fake cluster is
// "gr447979-30003" (NodePort 30003). It expects DeleteTCPRoute to answer 200
// and the actual service to be gone afterwards — presumably located through
// the shared node port, as the test name suggests.
func TestDeleteTCPRouteFallsBackToNodePortWhenNameDiffers(t *testing.T) {
	const (
		namespace         = "default"
		serviceAlias      = "gr447979"
		actualServiceName = "gr447979-30003"
		requestedName     = "custom-k8s-service-30003"
		nodePort          = int32(30003)
	)

	tcpPort := corev1.ServicePort{
		Name:       actualServiceName,
		Protocol:   corev1.ProtocolTCP,
		Port:       6379,
		TargetPort: intstr.FromInt(6379),
		NodePort:   nodePort,
	}
	existing := &corev1.Service{
		ObjectMeta: v1.ObjectMeta{
			Name:      actualServiceName,
			Namespace: namespace,
			Labels: map[string]string{
				"service_alias": serviceAlias,
				"outer":         "true",
				"tcp":           "true",
				"port":          "6379",
			},
		},
		Spec: corev1.ServiceSpec{
			Ports: []corev1.ServicePort{tcpPort},
			Type:  corev1.ServiceTypeNodePort,
		},
	}
	clientset, closeServer := newTCPRouteTestClientset(t, map[string]*corev1.Service{actualServiceName: existing})
	defer closeServer()
	k8s.New().Clientset = clientset

	// The handler reads the route name from chi's URL params and the tenant
	// from the request context, so both are injected by hand.
	routeCtx := chi.NewRouteContext()
	routeCtx.URLParams.Add("name", requestedName)
	tenant := &dbmodel.Tenants{
		Namespace: namespace,
	}

	req := httptest.NewRequest(http.MethodDelete, "/"+requestedName, nil)
	ctx := context.WithValue(req.Context(), chi.RouteCtxKey, routeCtx)
	ctx = context.WithValue(ctx, ctxutil.ContextKey("tenant"), tenant)
	req = req.WithContext(ctx)
	rr := httptest.NewRecorder()

	Struct{}.DeleteTCPRoute(rr, req)

	if status := rr.Code; status != http.StatusOK {
		t.Fatalf("expected status 200, got %d: %s", status, rr.Body.String())
	}

	serviceClient := k8s.Default().Clientset.CoreV1().Services(namespace)
	_, err := serviceClient.Get(context.Background(), actualServiceName, v1.GetOptions{})
	if deleted := errors.IsNotFound(err); !deleted {
		t.Fatalf("expected actual TCP route service to be deleted, got err %v", err)
	}
}