package embedder import ( "bytes" "context" "encoding/json" "fmt" "net/http" "os" ) // Embedder generates embeddings for text type Embedder interface { Embed(ctx context.Context, texts []string) ([][]float32, error) Dimensions() int } // OllamaEmbedder uses Ollama's embedding API type OllamaEmbedder struct { baseURL string model string dims int } // NewOllamaEmbedder creates an Ollama embedder func NewOllamaEmbedder(model string) *OllamaEmbedder { baseURL := os.Getenv("CODEVEC_BASE_URL") if baseURL == "" { baseURL = "http://localhost:11434" } if model == "" { model = "nomic-embed-text" } // Model dimensions dims := 768 // nomic-embed-text default switch model { case "mxbai-embed-large": dims = 1024 case "all-minilm": dims = 384 } return &OllamaEmbedder{ baseURL: baseURL, model: model, dims: dims, } } func (e *OllamaEmbedder) Dimensions() int { return e.dims } type ollamaRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` } type ollamaResponse struct { Embedding []float32 `json:"embedding"` } func (e *OllamaEmbedder) Embed(ctx context.Context, texts []string) ([][]float32, error) { embeddings := make([][]float32, len(texts)) // Ollama's /api/embeddings takes one prompt at a time for i, text := range texts { req := ollamaRequest{ Model: e.model, Prompt: text, } body, err := json.Marshal(req) if err != nil { return nil, err } httpReq, err := http.NewRequestWithContext(ctx, "POST", e.baseURL+"/api/embeddings", bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("ollama request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("ollama returned status %d", resp.StatusCode) } var result ollamaResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, err } embeddings[i] = result.Embedding } return embeddings, nil } // OpenAIEmbedder uses OpenAI-compatible embedding API type OpenAIEmbedder struct { baseURL string apiKey string model string dims int } // NewOpenAIEmbedder creates an OpenAI-compatible embedder func NewOpenAIEmbedder(model string) *OpenAIEmbedder { baseURL := os.Getenv("CODEVEC_BASE_URL") if baseURL == "" { baseURL = "https://api.openai.com" } apiKey := os.Getenv("CODEVEC_API_KEY") if model == "" { model = "text-embedding-3-small" } dims := 1536 // text-embedding-3-small default switch model { case "text-embedding-3-large": dims = 3072 case "text-embedding-ada-002": dims = 1536 } return &OpenAIEmbedder{ baseURL: baseURL, apiKey: apiKey, model: model, dims: dims, } } func (e *OpenAIEmbedder) Dimensions() int { return e.dims } type openaiRequest struct { Model string `json:"model"` Input []string `json:"input"` } type openaiResponse struct { Data []struct { Embedding []float32 `json:"embedding"` } `json:"data"` } func (e *OpenAIEmbedder) Embed(ctx context.Context, texts []string) ([][]float32, error) { if e.apiKey == "" { return nil, fmt.Errorf("CODEVEC_API_KEY not set") } // Batch in groups of 100 const batchSize = 100 embeddings := make([][]float32, len(texts)) for start := 0; start < len(texts); start += batchSize { end := start + batchSize if end > len(texts) { end = len(texts) } batch := texts[start:end] req := openaiRequest{ Model: e.model, Input: batch, } body, err := json.Marshal(req) if err != nil { return nil, err } httpReq, err := http.NewRequestWithContext(ctx, "POST", e.baseURL+"/v1/embeddings", bytes.NewReader(body)) if err != nil { return nil, err } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+e.apiKey) resp, err := http.DefaultClient.Do(httpReq) if err != nil { return nil, fmt.Errorf("openai request failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("openai returned status %d", resp.StatusCode) } var result openaiResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { return nil, err } for i, d := range result.Data { embeddings[start+i] = d.Embedding } } return embeddings, nil } // New creates an embedder based on provider name func New(provider, model string) (Embedder, error) { switch provider { case "ollama": return NewOllamaEmbedder(model), nil case "openai": return NewOpenAIEmbedder(model), nil default: return nil, fmt.Errorf("unknown provider: %s", provider) } }