feat: add custom OpenAI-compatible provider with custom headers and parameters
- Add Custom provider for any OpenAI-compatible API - All providers now support custom headers and parameters - providerSettings stored per-provider with headers and parameters - Example: disable thinking in Ollama, add tools, etc.
This commit is contained in:
@@ -447,10 +447,11 @@ CRITICAL RULES:
|
|||||||
const baseUrl = providerConfig.baseUrl || settings?.aiBaseUrl;
|
const baseUrl = providerConfig.baseUrl || settings?.aiBaseUrl;
|
||||||
|
|
||||||
const aiProvider = createAIProvider({
|
const aiProvider = createAIProvider({
|
||||||
provider: provider as 'openai' | 'anthropic' | 'ollama' | 'lmstudio' | 'groq',
|
provider,
|
||||||
apiKey: apiKey,
|
apiKey: apiKey,
|
||||||
model: model || undefined,
|
model: model || undefined,
|
||||||
baseUrl: (provider === 'ollama' || provider === 'lmstudio') ? baseUrl : undefined,
|
baseUrl: (provider === 'ollama' || provider === 'lmstudio' || provider === 'custom') ? baseUrl : undefined,
|
||||||
|
providerSettings: settings?.providerSettings,
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log(`[Journal Generate] AI Provider created: ${aiProvider.provider}`);
|
console.log(`[Journal Generate] AI Provider created: ${aiProvider.provider}`);
|
||||||
@@ -657,7 +658,7 @@ app.post('/api/v1/ai/test', async (c) => {
|
|||||||
const userId = await getUserId(c);
|
const userId = await getUserId(c);
|
||||||
if (!userId) return c.json({ data: null, error: { code: 'UNAUTHORIZED', message: 'Invalid API key' } }, 401);
|
if (!userId) return c.json({ data: null, error: { code: 'UNAUTHORIZED', message: 'Invalid API key' } }, 401);
|
||||||
|
|
||||||
const { provider, apiKey, model, baseUrl } = await c.req.json();
|
const { provider, apiKey, model, baseUrl, headers, parameters } = await c.req.json();
|
||||||
|
|
||||||
console.log(`[AI Test] Provider: ${provider}, Model: ${model || 'default'}, BaseURL: ${baseUrl || 'default'}`);
|
console.log(`[AI Test] Provider: ${provider}, Model: ${model || 'default'}, BaseURL: ${baseUrl || 'default'}`);
|
||||||
console.log(`[AI Test] API Key set: ${!!apiKey}, Length: ${apiKey?.length || 0}`);
|
console.log(`[AI Test] API Key set: ${!!apiKey}, Length: ${apiKey?.length || 0}`);
|
||||||
@@ -666,12 +667,18 @@ app.post('/api/v1/ai/test', async (c) => {
|
|||||||
return c.json({ data: null, error: { code: 'VALIDATION_ERROR', message: 'provider is required' } }, 400);
|
return c.json({ data: null, error: { code: 'VALIDATION_ERROR', message: 'provider is required' } }, 400);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let providerSettings: string | undefined;
|
||||||
|
if (headers || parameters) {
|
||||||
|
providerSettings = JSON.stringify({ headers, parameters });
|
||||||
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const aiProvider = createAIProvider({
|
const aiProvider = createAIProvider({
|
||||||
provider,
|
provider,
|
||||||
apiKey: apiKey || '',
|
apiKey: apiKey || '',
|
||||||
model: model || undefined,
|
model: model || undefined,
|
||||||
baseUrl: baseUrl || undefined,
|
baseUrl: baseUrl || undefined,
|
||||||
|
providerSettings,
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log(`[AI Test] Creating provider: ${aiProvider.provider}`);
|
console.log(`[AI Test] Creating provider: ${aiProvider.provider}`);
|
||||||
|
|||||||
@@ -28,10 +28,11 @@ journalRoutes.post('/generate/:date', async (c) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const provider = createAIProvider({
|
const provider = createAIProvider({
|
||||||
provider: settings.aiProvider as AIProvider['provider'],
|
provider: settings.aiProvider,
|
||||||
apiKey: settings.aiApiKey,
|
apiKey: settings.aiApiKey,
|
||||||
model: settings.aiModel,
|
model: settings.aiModel,
|
||||||
baseUrl: settings.aiBaseUrl,
|
baseUrl: settings.aiBaseUrl,
|
||||||
|
providerSettings: settings.providerSettings,
|
||||||
});
|
});
|
||||||
|
|
||||||
const eventsText = events.map(event => {
|
const eventsText = events.map(event => {
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ settingsRoutes.put('/', async (c) => {
|
|||||||
const prisma = c.get('prisma');
|
const prisma = c.get('prisma');
|
||||||
|
|
||||||
const body = await c.req.json();
|
const body = await c.req.json();
|
||||||
const { aiProvider, aiApiKey, aiModel, aiBaseUrl, journalPrompt, language } = body;
|
const { aiProvider, aiApiKey, aiModel, aiBaseUrl, journalPrompt, language, providerSettings } = body;
|
||||||
|
|
||||||
const data: Record<string, unknown> = {};
|
const data: Record<string, unknown> = {};
|
||||||
if (aiProvider !== undefined) data.aiProvider = aiProvider;
|
if (aiProvider !== undefined) data.aiProvider = aiProvider;
|
||||||
@@ -35,6 +35,7 @@ settingsRoutes.put('/', async (c) => {
|
|||||||
if (aiBaseUrl !== undefined) data.aiBaseUrl = aiBaseUrl;
|
if (aiBaseUrl !== undefined) data.aiBaseUrl = aiBaseUrl;
|
||||||
if (journalPrompt !== undefined) data.journalPrompt = journalPrompt;
|
if (journalPrompt !== undefined) data.journalPrompt = journalPrompt;
|
||||||
if (language !== undefined) data.language = language;
|
if (language !== undefined) data.language = language;
|
||||||
|
if (providerSettings !== undefined) data.providerSettings = typeof providerSettings === 'string' ? providerSettings : JSON.stringify(providerSettings);
|
||||||
|
|
||||||
const settings = await prisma.settings.upsert({
|
const settings = await prisma.settings.upsert({
|
||||||
where: { userId },
|
where: { userId },
|
||||||
@@ -70,12 +71,17 @@ settingsRoutes.post('/validate-key', async (c) => {
|
|||||||
|
|
||||||
settingsRoutes.post('/test', async (c) => {
|
settingsRoutes.post('/test', async (c) => {
|
||||||
const body = await c.req.json();
|
const body = await c.req.json();
|
||||||
const { provider, apiKey, model, baseUrl } = body;
|
const { provider, apiKey, model, baseUrl, headers, parameters } = body;
|
||||||
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
return c.json({ data: null, error: { code: 'VALIDATION_ERROR', message: 'provider is required' } }, 400);
|
return c.json({ data: null, error: { code: 'VALIDATION_ERROR', message: 'provider is required' } }, 400);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let providerSettings: string | undefined;
|
||||||
|
if (headers || parameters) {
|
||||||
|
providerSettings = JSON.stringify({ headers, parameters });
|
||||||
|
}
|
||||||
|
|
||||||
const { createAIProvider } = await import('../services/ai/provider');
|
const { createAIProvider } = await import('../services/ai/provider');
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -84,11 +90,12 @@ settingsRoutes.post('/test', async (c) => {
|
|||||||
apiKey: apiKey || '',
|
apiKey: apiKey || '',
|
||||||
model: model || undefined,
|
model: model || undefined,
|
||||||
baseUrl: baseUrl || undefined,
|
baseUrl: baseUrl || undefined,
|
||||||
|
providerSettings,
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await aiProvider.generate('Say "OK" if you can read this.', 'You are a test assistant. Respond with just "OK".');
|
const result = await aiProvider.generate('Say "OK" if you can read this.', 'You are a test assistant. Respond with just "OK".');
|
||||||
|
|
||||||
if (result.toLowerCase().includes('ok')) {
|
if (result.content.toLowerCase().includes('ok')) {
|
||||||
return c.json({ data: { valid: true, message: 'Connection successful!' }, error: null });
|
return c.json({ data: { valid: true, message: 'Connection successful!' }, error: null });
|
||||||
} else {
|
} else {
|
||||||
return c.json({ data: { valid: false }, error: { code: 'TEST_FAILED', message: 'Model responded but with unexpected output' } });
|
return c.json({ data: { valid: false }, error: { code: 'TEST_FAILED', message: 'Model responded but with unexpected output' } });
|
||||||
|
|||||||
@@ -1,15 +1,20 @@
|
|||||||
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
||||||
|
import { ProviderSettings } from './openai';
|
||||||
|
|
||||||
export class AnthropicProvider implements AIProvider {
|
export class AnthropicProvider implements AIProvider {
|
||||||
provider = 'anthropic' as const;
|
provider = 'anthropic' as const;
|
||||||
private apiKey: string;
|
private apiKey: string;
|
||||||
private model: string;
|
private model: string;
|
||||||
private baseUrl: string;
|
private baseUrl: string;
|
||||||
|
private customHeaders?: Record<string, string>;
|
||||||
|
private customParameters?: Record<string, unknown>;
|
||||||
|
|
||||||
constructor(config: AIProviderConfig) {
|
constructor(config: AIProviderConfig, settings?: ProviderSettings) {
|
||||||
this.apiKey = config.apiKey;
|
this.apiKey = config.apiKey;
|
||||||
this.model = config.model || 'claude-3-sonnet-20240229';
|
this.model = config.model || 'claude-3-sonnet-20240229';
|
||||||
this.baseUrl = config.baseUrl || 'https://api.anthropic.com/v1';
|
this.baseUrl = config.baseUrl || 'https://api.anthropic.com/v1';
|
||||||
|
this.customHeaders = settings?.headers;
|
||||||
|
this.customParameters = settings?.parameters;
|
||||||
}
|
}
|
||||||
|
|
||||||
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
||||||
@@ -20,20 +25,24 @@ export class AnthropicProvider implements AIProvider {
|
|||||||
messages: [
|
messages: [
|
||||||
{ role: 'user', content: prompt }
|
{ role: 'user', content: prompt }
|
||||||
],
|
],
|
||||||
|
...this.customParameters,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (options?.jsonMode) {
|
if (options?.jsonMode) {
|
||||||
requestBody.output = { format: { type: 'json_object' } };
|
requestBody.output = { format: { type: 'json_object' } };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'x-api-key': this.apiKey,
|
||||||
|
'anthropic-version': '2023-06-01',
|
||||||
|
'anthropic-dangerous-direct-browser-access': 'true',
|
||||||
|
...this.customHeaders,
|
||||||
|
};
|
||||||
|
|
||||||
const response = await fetch(`${this.baseUrl}/messages`, {
|
const response = await fetch(`${this.baseUrl}/messages`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers,
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'x-api-key': this.apiKey,
|
|
||||||
'anthropic-version': '2023-06-01',
|
|
||||||
'anthropic-dangerous-direct-browser-access': 'true',
|
|
||||||
},
|
|
||||||
body: JSON.stringify(requestBody),
|
body: JSON.stringify(requestBody),
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -66,18 +75,21 @@ export class AnthropicProvider implements AIProvider {
|
|||||||
|
|
||||||
async validate(): Promise<boolean> {
|
async validate(): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'x-api-key': this.apiKey,
|
||||||
|
'anthropic-version': '2023-06-01',
|
||||||
|
'anthropic-dangerous-direct-browser-access': 'true',
|
||||||
|
...this.customHeaders,
|
||||||
|
};
|
||||||
const response = await fetch(`${this.baseUrl}/messages`, {
|
const response = await fetch(`${this.baseUrl}/messages`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers,
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'x-api-key': this.apiKey,
|
|
||||||
'anthropic-version': '2023-06-01',
|
|
||||||
'anthropic-dangerous-direct-browser-access': 'true',
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
max_tokens: 1,
|
max_tokens: 1,
|
||||||
messages: [{ role: 'user', content: 'hi' }],
|
messages: [{ role: 'user', content: 'hi' }],
|
||||||
|
...this.customParameters,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
return response.ok;
|
return response.ok;
|
||||||
|
|||||||
94
backend/src/services/ai/custom.ts
Normal file
94
backend/src/services/ai/custom.ts
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
||||||
|
import { ProviderSettings } from './openai';
|
||||||
|
|
||||||
|
export class CustomProvider implements AIProvider {
|
||||||
|
provider = 'custom' as const;
|
||||||
|
private apiKey: string;
|
||||||
|
private model: string;
|
||||||
|
private baseUrl: string;
|
||||||
|
private customHeaders?: Record<string, string>;
|
||||||
|
private customParameters?: Record<string, unknown>;
|
||||||
|
|
||||||
|
constructor(config: AIProviderConfig, settings?: ProviderSettings) {
|
||||||
|
this.apiKey = config.apiKey || '';
|
||||||
|
this.model = config.model || 'unknown';
|
||||||
|
this.baseUrl = config.baseUrl || '';
|
||||||
|
this.customHeaders = settings?.headers;
|
||||||
|
this.customParameters = settings?.parameters;
|
||||||
|
}
|
||||||
|
|
||||||
|
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
||||||
|
const messages: Array<{ role: string; content: string }> = [];
|
||||||
|
|
||||||
|
if (systemPrompt) {
|
||||||
|
messages.push({ role: 'system', content: systemPrompt });
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.push({ role: 'user', content: prompt });
|
||||||
|
|
||||||
|
const requestBody: Record<string, unknown> = {
|
||||||
|
model: this.model,
|
||||||
|
messages,
|
||||||
|
temperature: 0.7,
|
||||||
|
max_tokens: 2000,
|
||||||
|
...this.customParameters,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (options?.jsonMode) {
|
||||||
|
requestBody.response_format = { type: 'json_object' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
...(this.apiKey ? { 'Authorization': `Bearer ${this.apiKey}` } : {}),
|
||||||
|
...this.customHeaders,
|
||||||
|
};
|
||||||
|
|
||||||
|
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers,
|
||||||
|
body: JSON.stringify(requestBody),
|
||||||
|
});
|
||||||
|
|
||||||
|
const responseData = await response.json();
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Custom API error: ${response.status} ${JSON.stringify(responseData)}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
let content = responseData.choices?.[0]?.message?.content || '';
|
||||||
|
let title: string | undefined;
|
||||||
|
|
||||||
|
if (options?.jsonMode) {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(content);
|
||||||
|
title = parsed.title;
|
||||||
|
content = parsed.content || content;
|
||||||
|
} catch {
|
||||||
|
// If JSON parsing fails, use content as-is
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
content,
|
||||||
|
title,
|
||||||
|
request: requestBody,
|
||||||
|
response: responseData,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
async validate(): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
if (!this.baseUrl) return false;
|
||||||
|
const response = await fetch(`${this.baseUrl}/models`, {
|
||||||
|
headers: {
|
||||||
|
...(this.apiKey ? { 'Authorization': `Bearer ${this.apiKey}` } : {}),
|
||||||
|
...this.customHeaders,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
return response.ok;
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,15 +1,20 @@
|
|||||||
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
||||||
|
import { ProviderSettings } from './openai';
|
||||||
|
|
||||||
export class GroqProvider implements AIProvider {
|
export class GroqProvider implements AIProvider {
|
||||||
provider = 'groq' as const;
|
provider = 'groq' as const;
|
||||||
private apiKey: string;
|
private apiKey: string;
|
||||||
private model: string;
|
private model: string;
|
||||||
private baseUrl: string;
|
private baseUrl: string;
|
||||||
|
private customHeaders?: Record<string, string>;
|
||||||
|
private customParameters?: Record<string, unknown>;
|
||||||
|
|
||||||
constructor(config: AIProviderConfig) {
|
constructor(config: AIProviderConfig, settings?: ProviderSettings) {
|
||||||
this.apiKey = config.apiKey;
|
this.apiKey = config.apiKey;
|
||||||
this.model = config.model || 'llama-3.3-70b-versatile';
|
this.model = config.model || 'llama-3.3-70b-versatile';
|
||||||
this.baseUrl = config.baseUrl || 'https://api.groq.com/openai/v1';
|
this.baseUrl = config.baseUrl || 'https://api.groq.com/openai/v1';
|
||||||
|
this.customHeaders = settings?.headers;
|
||||||
|
this.customParameters = settings?.parameters;
|
||||||
}
|
}
|
||||||
|
|
||||||
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
||||||
@@ -26,18 +31,22 @@ export class GroqProvider implements AIProvider {
|
|||||||
messages,
|
messages,
|
||||||
temperature: 0.7,
|
temperature: 0.7,
|
||||||
max_tokens: 2000,
|
max_tokens: 2000,
|
||||||
|
...this.customParameters,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (options?.jsonMode) {
|
if (options?.jsonMode) {
|
||||||
requestBody.response_format = { type: 'json_object' };
|
requestBody.response_format = { type: 'json_object' };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': `Bearer ${this.apiKey}`,
|
||||||
|
...this.customHeaders,
|
||||||
|
};
|
||||||
|
|
||||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers,
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': `Bearer ${this.apiKey}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify(requestBody),
|
body: JSON.stringify(requestBody),
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -71,16 +80,19 @@ export class GroqProvider implements AIProvider {
|
|||||||
|
|
||||||
async validate(): Promise<boolean> {
|
async validate(): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': `Bearer ${this.apiKey}`,
|
||||||
|
...this.customHeaders,
|
||||||
|
};
|
||||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers,
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': `Bearer ${this.apiKey}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
model: this.model,
|
model: this.model,
|
||||||
messages: [{ role: 'user', content: 'test' }],
|
messages: [{ role: 'user', content: 'test' }],
|
||||||
max_tokens: 5,
|
max_tokens: 5,
|
||||||
|
...this.customParameters,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
return response.ok || response.status === 400;
|
return response.ok || response.status === 400;
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
||||||
|
import { ProviderSettings } from './openai';
|
||||||
|
|
||||||
export class LMStudioProvider implements AIProvider {
|
export class LMStudioProvider implements AIProvider {
|
||||||
provider = 'lmstudio' as const;
|
provider = 'lmstudio' as const;
|
||||||
private baseUrl: string;
|
private baseUrl: string;
|
||||||
private model: string;
|
private model: string;
|
||||||
|
private customHeaders?: Record<string, string>;
|
||||||
|
private customParameters?: Record<string, unknown>;
|
||||||
|
|
||||||
constructor(config: AIProviderConfig) {
|
constructor(config: AIProviderConfig, settings?: ProviderSettings) {
|
||||||
this.baseUrl = config.baseUrl || 'http://localhost:1234/v1';
|
this.baseUrl = config.baseUrl || 'http://localhost:1234/v1';
|
||||||
this.model = config.model || 'local-model';
|
this.model = config.model || 'local-model';
|
||||||
|
this.customHeaders = settings?.headers;
|
||||||
|
this.customParameters = settings?.parameters;
|
||||||
}
|
}
|
||||||
|
|
||||||
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
||||||
@@ -19,18 +24,26 @@ export class LMStudioProvider implements AIProvider {
|
|||||||
|
|
||||||
messages.push({ role: 'user', content: prompt });
|
messages.push({ role: 'user', content: prompt });
|
||||||
|
|
||||||
const requestBody = {
|
const requestBody: Record<string, unknown> = {
|
||||||
model: this.model,
|
model: this.model,
|
||||||
messages,
|
messages,
|
||||||
temperature: 0.7,
|
temperature: 0.7,
|
||||||
max_tokens: 2000,
|
max_tokens: 2000,
|
||||||
|
...this.customParameters,
|
||||||
|
};
|
||||||
|
|
||||||
|
if (options?.jsonMode) {
|
||||||
|
requestBody.response_format = { type: 'json_object' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
...this.customHeaders,
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers,
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
body: JSON.stringify(requestBody),
|
body: JSON.stringify(requestBody),
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -63,7 +76,9 @@ export class LMStudioProvider implements AIProvider {
|
|||||||
|
|
||||||
async validate(): Promise<boolean> {
|
async validate(): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`${this.baseUrl}/models`);
|
const response = await fetch(`${this.baseUrl}/models`, {
|
||||||
|
headers: this.customHeaders,
|
||||||
|
});
|
||||||
return response.ok;
|
return response.ok;
|
||||||
} catch {
|
} catch {
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -1,30 +1,39 @@
|
|||||||
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
||||||
|
import { ProviderSettings } from './openai';
|
||||||
|
|
||||||
export class OllamaProvider implements AIProvider {
|
export class OllamaProvider implements AIProvider {
|
||||||
provider = 'ollama' as const;
|
provider = 'ollama' as const;
|
||||||
private baseUrl: string;
|
private baseUrl: string;
|
||||||
private model: string;
|
private model: string;
|
||||||
|
private customHeaders?: Record<string, string>;
|
||||||
|
private customParameters?: Record<string, unknown>;
|
||||||
|
|
||||||
constructor(config: AIProviderConfig) {
|
constructor(config: AIProviderConfig, settings?: ProviderSettings) {
|
||||||
this.baseUrl = config.baseUrl || 'http://localhost:11434';
|
this.baseUrl = config.baseUrl || 'http://localhost:11434';
|
||||||
this.model = config.model || 'llama3.2';
|
this.model = config.model || 'llama3.2';
|
||||||
|
this.customHeaders = settings?.headers;
|
||||||
|
this.customParameters = settings?.parameters;
|
||||||
}
|
}
|
||||||
|
|
||||||
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
||||||
const requestBody = {
|
const requestBody: Record<string, unknown> = {
|
||||||
model: this.model,
|
model: this.model,
|
||||||
stream: false,
|
stream: false,
|
||||||
messages: [
|
messages: [
|
||||||
...(systemPrompt ? [{ role: 'system', content: systemPrompt }] : []),
|
...(systemPrompt ? [{ role: 'system', content: systemPrompt }] : []),
|
||||||
{ role: 'user', content: prompt },
|
{ role: 'user', content: prompt },
|
||||||
],
|
],
|
||||||
|
...this.customParameters,
|
||||||
|
};
|
||||||
|
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
...this.customHeaders,
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await fetch(`${this.baseUrl}/api/chat`, {
|
const response = await fetch(`${this.baseUrl}/api/chat`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers,
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
body: JSON.stringify(requestBody),
|
body: JSON.stringify(requestBody),
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -57,7 +66,9 @@ export class OllamaProvider implements AIProvider {
|
|||||||
|
|
||||||
async validate(): Promise<boolean> {
|
async validate(): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`${this.baseUrl}/api/tags`);
|
const response = await fetch(`${this.baseUrl}/api/tags`, {
|
||||||
|
headers: this.customHeaders,
|
||||||
|
});
|
||||||
return response.ok;
|
return response.ok;
|
||||||
} catch {
|
} catch {
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -1,15 +1,24 @@
|
|||||||
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
import type { AIProvider, AIProviderConfig, AIProviderResult } from './provider';
|
||||||
|
|
||||||
|
export interface ProviderSettings {
|
||||||
|
headers?: Record<string, string>;
|
||||||
|
parameters?: Record<string, unknown>;
|
||||||
|
}
|
||||||
|
|
||||||
export class OpenAIProvider implements AIProvider {
|
export class OpenAIProvider implements AIProvider {
|
||||||
provider = 'openai' as const;
|
provider = 'openai' as const;
|
||||||
private apiKey: string;
|
private apiKey: string;
|
||||||
private model: string;
|
private model: string;
|
||||||
private baseUrl: string;
|
private baseUrl: string;
|
||||||
|
private customHeaders?: Record<string, string>;
|
||||||
|
private customParameters?: Record<string, unknown>;
|
||||||
|
|
||||||
constructor(config: AIProviderConfig) {
|
constructor(config: AIProviderConfig, settings?: ProviderSettings) {
|
||||||
this.apiKey = config.apiKey;
|
this.apiKey = config.apiKey;
|
||||||
this.model = config.model || 'gpt-4';
|
this.model = config.model || 'gpt-4';
|
||||||
this.baseUrl = config.baseUrl || 'https://api.openai.com/v1';
|
this.baseUrl = config.baseUrl || 'https://api.openai.com/v1';
|
||||||
|
this.customHeaders = settings?.headers;
|
||||||
|
this.customParameters = settings?.parameters;
|
||||||
}
|
}
|
||||||
|
|
||||||
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
async generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult> {
|
||||||
@@ -26,18 +35,22 @@ export class OpenAIProvider implements AIProvider {
|
|||||||
messages,
|
messages,
|
||||||
temperature: 0.7,
|
temperature: 0.7,
|
||||||
max_tokens: 2000,
|
max_tokens: 2000,
|
||||||
|
...this.customParameters,
|
||||||
};
|
};
|
||||||
|
|
||||||
if (options?.jsonMode) {
|
if (options?.jsonMode) {
|
||||||
requestBody.response_format = { type: 'json_object' };
|
requestBody.response_format = { type: 'json_object' };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': `Bearer ${this.apiKey}`,
|
||||||
|
...this.customHeaders,
|
||||||
|
};
|
||||||
|
|
||||||
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
const response = await fetch(`${this.baseUrl}/chat/completions`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers,
|
||||||
'Content-Type': 'application/json',
|
|
||||||
'Authorization': `Bearer ${this.apiKey}`,
|
|
||||||
},
|
|
||||||
body: JSON.stringify(requestBody),
|
body: JSON.stringify(requestBody),
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -70,11 +83,11 @@ export class OpenAIProvider implements AIProvider {
|
|||||||
|
|
||||||
async validate(): Promise<boolean> {
|
async validate(): Promise<boolean> {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`${this.baseUrl}/models`, {
|
const headers: Record<string, string> = {
|
||||||
headers: {
|
'Authorization': `Bearer ${this.apiKey}`,
|
||||||
'Authorization': `Bearer ${this.apiKey}`,
|
...this.customHeaders,
|
||||||
},
|
};
|
||||||
});
|
const response = await fetch(`${this.baseUrl}/models`, { headers });
|
||||||
return response.ok;
|
return response.ok;
|
||||||
} catch {
|
} catch {
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@@ -6,16 +6,33 @@ export interface AIProviderResult {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface AIProvider {
|
export interface AIProvider {
|
||||||
provider: 'openai' | 'anthropic' | 'ollama' | 'lmstudio' | 'groq';
|
provider: string;
|
||||||
generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult>;
|
generate(prompt: string, systemPrompt?: string, options?: { jsonMode?: boolean }): Promise<AIProviderResult>;
|
||||||
validate?(): Promise<boolean>;
|
validate?(): Promise<boolean>;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AIProviderConfig {
|
export interface AIProviderConfig {
|
||||||
provider: 'openai' | 'anthropic' | 'ollama' | 'lmstudio' | 'groq';
|
provider: string;
|
||||||
apiKey: string;
|
apiKey?: string;
|
||||||
model?: string;
|
model?: string;
|
||||||
baseUrl?: string;
|
baseUrl?: string;
|
||||||
|
providerSettings?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseProviderSettings(settingsJson?: string): {
|
||||||
|
headers?: Record<string, string>;
|
||||||
|
parameters?: Record<string, unknown>;
|
||||||
|
} {
|
||||||
|
if (!settingsJson) return {};
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(settingsJson);
|
||||||
|
return {
|
||||||
|
headers: parsed.headers,
|
||||||
|
parameters: parsed.parameters,
|
||||||
|
};
|
||||||
|
} catch {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
import { OpenAIProvider } from './openai';
|
import { OpenAIProvider } from './openai';
|
||||||
@@ -23,19 +40,24 @@ import { AnthropicProvider } from './anthropic';
|
|||||||
import { OllamaProvider } from './ollama';
|
import { OllamaProvider } from './ollama';
|
||||||
import { LMStudioProvider } from './lmstudio';
|
import { LMStudioProvider } from './lmstudio';
|
||||||
import { GroqProvider } from './groq';
|
import { GroqProvider } from './groq';
|
||||||
|
import { CustomProvider } from './custom';
|
||||||
|
|
||||||
export function createAIProvider(config: AIProviderConfig): AIProvider {
|
export function createAIProvider(config: AIProviderConfig): AIProvider {
|
||||||
|
const settings = parseProviderSettings(config.providerSettings);
|
||||||
|
|
||||||
switch (config.provider) {
|
switch (config.provider) {
|
||||||
case 'openai':
|
case 'openai':
|
||||||
return new OpenAIProvider(config);
|
return new OpenAIProvider(config, settings);
|
||||||
case 'anthropic':
|
case 'anthropic':
|
||||||
return new AnthropicProvider(config);
|
return new AnthropicProvider(config, settings);
|
||||||
case 'ollama':
|
case 'ollama':
|
||||||
return new OllamaProvider(config);
|
return new OllamaProvider(config, settings);
|
||||||
case 'lmstudio':
|
case 'lmstudio':
|
||||||
return new LMStudioProvider(config);
|
return new LMStudioProvider(config, settings);
|
||||||
case 'groq':
|
case 'groq':
|
||||||
return new GroqProvider(config);
|
return new GroqProvider(config, settings);
|
||||||
|
case 'custom':
|
||||||
|
return new CustomProvider(config, settings);
|
||||||
default:
|
default:
|
||||||
throw new Error(`Unknown AI provider: ${config.provider}`);
|
throw new Error(`Unknown AI provider: ${config.provider}`);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user